/** * Copyright 2013-2015 Pierre Merienne * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.github.pmerienne.trident.ml.classification; import com.github.pmerienne.trident.ml.util.MathUtil; /** * Passive-Aggresive binary classifier. * * @see Online Passive-Aggressive Algorithms * * Koby Crammer, Ofer Dekel, Joseph Keshet, Shai Shalev-Shwartz, Yoram * Singer; 7(Mar):551--585, 2006. * @author pmerienne * */ public class PAClassifier implements Classifier<Boolean> { private static final long serialVersionUID = -5163481593640555140L; private double[] weights; private Type type = Type.STANDARD; private Double aggressiveness = 0.001; public PAClassifier() { } public PAClassifier(Type type) { this.type = type; } public PAClassifier(Type type, Double aggressiveness) { this.type = type; this.aggressiveness = aggressiveness; } @Override public Boolean classify(double[] features) { if (this.weights == null) { this.init(features.length); } Double evaluation = MathUtil.dot(features, this.weights); Boolean prediction = evaluation >= 0 ? Boolean.TRUE : Boolean.FALSE; return prediction; } @Override public void update(Boolean expectedLabel, double[] features) { if (this.weights == null) { this.init(features.length); } Double expectedLabelAsInt = expectedLabel ? 1.0 : -1.0; double loss = Math.max(0.0, 1 - (expectedLabelAsInt * MathUtil.dot(this.weights, features))); double update = 0; if (Type.STANDARD.equals(this.type)) { update = loss / (1 + Math.pow(MathUtil.norm(features), 2)); } else if (Type.PA1.equals(this.type)) { update = Math.min(this.aggressiveness, loss / Math.pow(MathUtil.norm(features), 2)); } else if (Type.PA2.equals(this.type)) { update = loss / (Math.pow(MathUtil.norm(features), 2) + (1.0 / (2 * this.aggressiveness))); } double[] scaledFeatures = MathUtil.mult(features, update * expectedLabelAsInt); this.weights = MathUtil.add(this.weights, scaledFeatures); } protected void init(int featureSize) { // Init weights this.weights = new double[featureSize]; } @Override public void reset() { this.weights = null; } public double[] getWeights() { return weights; } public void setWeights(double[] weights) { this.weights = weights; } public Type getType() { return type; } public void setType(Type type) { this.type = type; } public Double getAggressiveness() { return aggressiveness; } public void setAggressiveness(Double aggressiveness) { this.aggressiveness = aggressiveness; } @Override public String toString() { return "PAClassifier [type=" + type + ", aggressiveness=" + aggressiveness + "]"; } public static enum Type { STANDARD, PA1, PA2; } }