/**
* Copyright 2013-2015 Pierre Merienne
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.pmerienne.trident.ml.classification;
import com.github.pmerienne.trident.ml.util.MathUtil;
/**
 * Online multi-class passive-aggressive (PA) classifier.
 *
 * <p>One weight vector is kept per class. A sample is assigned to the class
 * whose weight vector yields the highest {@code MathUtil.dot} score against
 * the feature vector. Labels are the zero-based indexes of the weight
 * vectors, i.e. integers in {@code [0, nbClasses)}.
 *
 * <p>The weight vectors are allocated lazily on the first call to
 * {@link #classify(double[])} (and therefore also {@link #update(Integer, double[])}),
 * using the length of the feature vector seen at that time. {@link #reset()}
 * discards them.
 *
 * <p>This class performs no synchronization of its own.
 */
public class MultiClassPAClassifier implements Classifier<Integer> {

    private static final long serialVersionUID = -5163481593640555140L;

    /** One weight vector per class; {@code null} until first use and after {@link #reset()}. */
    private double[][] weightVectors;

    /** Variant of the step-size formula used by {@link #update(Integer, double[])}. */
    private Type type = Type.STANDARD;

    /** Aggressiveness parameter; only read by the {@link Type#PA1} and {@link Type#PA2} variants. */
    private Double aggressiveness = 0.001;

    /** Number of classes; must be set before the weight vectors are first allocated. */
    private Integer nbClasses;

    /**
     * Creates a classifier without a class count. {@link #setNbClasses(Integer)}
     * (or {@link #setWeightVectors(double[][])}) must be called before first use.
     */
    public MultiClassPAClassifier() {
    }

    /**
     * Creates a {@link Type#STANDARD} classifier.
     *
     * @param nbClasses number of classes (labels are {@code 0 .. nbClasses - 1})
     */
    public MultiClassPAClassifier(Integer nbClasses) {
        this.nbClasses = nbClasses;
    }

    /**
     * Creates a classifier using the given step-size variant.
     *
     * @param nbClasses number of classes (labels are {@code 0 .. nbClasses - 1})
     * @param type step-size variant
     */
    public MultiClassPAClassifier(Integer nbClasses, Type type) {
        this.nbClasses = nbClasses;
        this.type = type;
    }

    /**
     * Creates a classifier using the given step-size variant and aggressiveness.
     *
     * @param nbClasses number of classes (labels are {@code 0 .. nbClasses - 1})
     * @param type step-size variant
     * @param aggressiveness aggressiveness parameter used by {@link Type#PA1} and {@link Type#PA2}
     */
    public MultiClassPAClassifier(Integer nbClasses, Type type, Double aggressiveness) {
        this.nbClasses = nbClasses;
        this.type = type;
        this.aggressiveness = aggressiveness;
    }

    /**
     * Returns the index of the class with the highest score for the given features.
     *
     * <p>Allocates the weight vectors first if they do not exist yet. On ties
     * the lowest class index wins, because only a strictly greater score
     * replaces the current best.
     *
     * @param features feature vector of the sample
     * @return the predicted class index, or {@code null} when no class score is
     *         strictly greater than {@code -Double.MAX_VALUE} (no classes, or
     *         only NaN / negative-infinite scores)
     * @throws IllegalStateException if the weight vectors must be allocated but
     *         the number of classes has not been set
     */
    @Override
    public Integer classify(double[] features) {
        if (this.weightVectors == null) {
            this.initWeightVectors(features.length);
        }

        Integer prediction = null;
        // Primitive doubles: avoids boxing/unboxing a Double for every class.
        double highestScore = -Double.MAX_VALUE;
        for (int label = 0; label < this.weightVectors.length; label++) {
            double score = MathUtil.dot(this.weightVectors[label], features);
            if (score > highestScore) {
                prediction = label;
                highestScore = score;
            }
        }
        return prediction;
    }

    /**
     * Performs one online learning step with a labelled sample.
     *
     * <p>The sample is first classified with the current weights. The loss is
     * {@code 1 - (score(expected) - score(predicted))}, and the step size
     * {@code tau} is derived from it according to {@link #getType()}. The
     * scaled features are then added to the expected class' weight vector and,
     * if the prediction differs from the expected label, subtracted from the
     * predicted class' weight vector.
     *
     * <p>NOTE: the loss is not clipped; when the prediction is already correct
     * the two scores cancel out, the loss is {@code 1} and the expected class'
     * weights are still updated.
     *
     * @param expectedLabel true class index of the sample, in {@code [0, number of weight vectors)}
     * @param features feature vector of the sample
     * @throws IllegalArgumentException if {@code expectedLabel} is {@code null} or out of range
     * @throws IllegalStateException if the number of classes has not been set, or
     *         if no class could be predicted for {@code features}
     */
    @Override
    public void update(Integer expectedLabel, double[] features) {
        // classify() also takes care of allocating the weight vectors on first use.
        Integer predictedLabel = this.classify(features);

        int classCount = this.weightVectors.length;
        if (expectedLabel == null || expectedLabel < 0 || expectedLabel >= classCount) {
            throw new IllegalArgumentException(
                    "Expected label must be in [0, " + classCount + ") but was " + expectedLabel);
        }
        if (predictedLabel == null) {
            throw new IllegalStateException(
                    "No class could be predicted: no score is greater than -Double.MAX_VALUE "
                            + "(NaN or infinite values in features or weights?)");
        }

        // Work on primitives so the label comparison below is a value comparison.
        int expected = expectedLabel;
        int predicted = predictedLabel;

        double expectedScore = MathUtil.dot(this.weightVectors[expected], features);
        double predictedScore = MathUtil.dot(this.weightVectors[predicted], features);
        double loss = 1.0 - (expectedScore - predictedScore);

        // Step size (Lagrange multiplier) of the update.
        double tau = this.computeTau(loss, features);

        // Pull the expected class towards the sample ...
        this.weightVectors[expected] = MathUtil.add(this.weightVectors[expected], MathUtil.mult(features, tau));
        // ... and push the wrongly predicted class away from it.
        if (predicted != expected) {
            this.weightVectors[predicted] = MathUtil.subtract(this.weightVectors[predicted], MathUtil.mult(features, tau));
        }
    }

    /**
     * Computes the step size for the configured {@link Type}.
     *
     * @return the step size, or {@code 0.0} when the type is {@code null}
     */
    private double computeTau(double loss, double[] features) {
        if (Type.STANDARD.equals(this.type)) {
            return loss / (1 + 2 * this.squaredNorm(features));
        }
        if (Type.PA1.equals(this.type)) {
            return Math.min(this.aggressiveness / 2, loss / (2 * this.squaredNorm(features)));
        }
        if (Type.PA2.equals(this.type)) {
            return 0.5 * (loss / (this.squaredNorm(features) + (1 / (2 * this.aggressiveness))));
        }
        return 0.0;
    }

    /** Returns the square of {@code MathUtil.norm(features)}. */
    private double squaredNorm(double[] features) {
        return Math.pow(MathUtil.norm(features), 2);
    }

    /**
     * Allocates one zeroed weight vector of length {@code featureSize} per class.
     *
     * @throws IllegalStateException if the number of classes has not been set
     */
    private void initWeightVectors(int featureSize) {
        if (this.nbClasses == null) {
            throw new IllegalStateException(
                    "nbClasses must be set (constructor or setNbClasses) before the classifier is used");
        }
        // A freshly allocated double array is already filled with 0.0.
        this.weightVectors = new double[this.nbClasses][featureSize];
    }

    /**
     * Discards the learned weight vectors; they are re-allocated on the next
     * call to {@link #classify(double[])} or {@link #update(Integer, double[])}.
     */
    @Override
    public void reset() {
        this.weightVectors = null;
    }

    /**
     * @return the internal weight vectors (not a copy), or {@code null} if not allocated yet
     */
    public double[][] getWeightVectors() {
        return weightVectors;
    }

    /**
     * @param weightVectors weight vectors to use, one per class; stored as is (not copied)
     */
    public void setWeightVectors(double[][] weightVectors) {
        this.weightVectors = weightVectors;
    }

    /** @return the step-size variant */
    public Type getType() {
        return type;
    }

    /** @param type the step-size variant; with {@code null} the step size is always {@code 0} */
    public void setType(Type type) {
        this.type = type;
    }

    /** @return the aggressiveness parameter */
    public Double getAggressiveness() {
        return aggressiveness;
    }

    /** @param aggressiveness the aggressiveness parameter used by {@link Type#PA1} and {@link Type#PA2} */
    public void setAggressiveness(Double aggressiveness) {
        this.aggressiveness = aggressiveness;
    }

    /** @return the number of classes */
    public Integer getNbClasses() {
        return nbClasses;
    }

    /**
     * @param nbClasses the number of classes; only taken into account the next
     *        time the weight vectors are allocated
     */
    public void setNbClasses(Integer nbClasses) {
        this.nbClasses = nbClasses;
    }

    @Override
    public String toString() {
        return "MultiClassPAClassifier [nbClasses=" + nbClasses + ", type=" + type + ", aggressiveness=" + aggressiveness + "]";
    }

    /** Step-size variants supported by {@link MultiClassPAClassifier#update(Integer, double[])}. */
    public enum Type {
        STANDARD, PA1, PA2;
    }
}