/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* SGD.java
* Copyright (C) 2009-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.functions;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Aggregateable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
/**
<!-- globalinfo-start -->
* Implements stochastic gradient descent for learning
* various linear models (binary class SVM, binary class logistic regression,
* squared loss, Huber loss and epsilon-insensitive loss linear regression).
* Globally replaces all missing values and transforms nominal attributes into
* binary ones. It also normalizes all attributes, so the coefficients in the
* output are based on the normalized data.<br/>
 * For numeric class attributes, the squared, Huber or epsilon-insensitive loss
* function must be used. Epsilon-insensitive and Huber loss may require a much
* higher learning rate.
* <p/>
<!-- globalinfo-end -->
*
<!-- options-start -->
* Valid options are:
* <p/>
*
* <pre>
* -F
 *  Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression),
 *  2 = squared loss (regression), 3 = epsilon-insensitive loss (regression),
 *  4 = Huber loss (regression).
 *  (default = 0)
* </pre>
*
* <pre>
* -L
* The learning rate. If normalization is
* turned off (as it is automatically for streaming data), then the
* default learning rate will need to be reduced (try 0.0001).
* (default = 0.01).
* </pre>
*
* <pre>
* -R <double>
* The lambda regularization constant (default = 0.0001)
* </pre>
*
* <pre>
* -E <integer>
* The number of epochs to perform (batch learning only, default = 500)
* </pre>
*
* <pre>
* -C <double>
 *  The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
* </pre>
*
* <pre>
* -N
* Don't normalize the data
* </pre>
*
* <pre>
* -M
* Don't replace missing values
* </pre>
*
<!-- options-end -->
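 *
 * A usage sketch, assuming a loaded Instances object named "data" with its
 * class index already set (the variable name is hypothetical):
 * <p/>
 *
 * <pre>
 * SGD sgd = new SGD();
 * sgd.setLossFunction(new SelectedTag(SGD.LOGLOSS, SGD.TAGS_SELECTION));
 * sgd.setLearningRate(0.01);
 * sgd.setLambda(0.0001);
 * sgd.buildClassifier(data); // batch training (500 epochs by default)
 * </pre>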
*
* @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz)
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @version $Revision: 9785 $
*
*/
public class SGD extends RandomizableClassifier implements
UpdateableClassifier, OptionHandler, Aggregateable<SGD> {
/** For serialization */
private static final long serialVersionUID = -3732968666673530290L;
/** Replace missing values */
protected ReplaceMissingValues m_replaceMissing;
/**
* Convert nominal attributes to numerically coded binary ones. Uses
* supervised NominalToBinary in the batch learning case
*/
protected Filter m_nominalToBinary;
/** Normalize the training data */
protected Normalize m_normalize;
/** The regularization parameter */
protected double m_lambda = 0.0001;
/** The learning rate */
protected double m_learningRate = 0.01;
/** Stores the weights (+ bias in the last element) */
protected double[] m_weights;
/** The epsilon parameter for epsilon insensitive and Huber loss */
protected double m_epsilon = 1e-3;
/** Holds the current iteration number */
protected double m_t;
/** The number of training instances */
protected double m_numInstances;
/**
* The number of epochs to perform (batch learning). Total iterations is
* m_epochs * num instances
*/
protected int m_epochs = 500;
/**
* Turn off normalization of the input data. This option gets forced for
* incremental training.
*/
protected boolean m_dontNormalize = false;
/**
* Turn off global replacement of missing values. Missing values will be
* ignored instead. This option gets forced for incremental training.
*/
protected boolean m_dontReplaceMissing = false;
/** Holds the header of the training data */
protected Instances m_data;
/**
* Returns default capabilities of the classifier.
*
* @return the capabilities of this classifier
*/
@Override
public Capabilities getCapabilities() {
Capabilities result = super.getCapabilities();
result.disableAll();
// attributes
result.enable(Capability.NOMINAL_ATTRIBUTES);
result.enable(Capability.NUMERIC_ATTRIBUTES);
result.enable(Capability.MISSING_VALUES);
// class
if (m_loss == SQUAREDLOSS || m_loss == EPSILON_INSENSITIVE
|| m_loss == HUBER)
result.enable(Capability.NUMERIC_CLASS);
else
result.enable(Capability.BINARY_CLASS);
result.enable(Capability.MISSING_CLASS_VALUES);
// instances
result.setMinimumNumberInstances(0);
return result;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for displaying in the
* explorer/experimenter gui
*/
public String epsilonTipText() {
return "The epsilon threshold for epsilon insensitive and Huber "
+ "loss. An error with absolute value less that this "
+ "threshold has loss of 0 for epsilon insensitive loss. "
+ "For Huber loss this is the boundary between the quadratic "
+ "and linear parts of the loss function.";
}
/**
* Set the epsilon threshold on the error for epsilon insensitive and Huber
* loss functions
*
* @param e the value of epsilon to use
*/
public void setEpsilon(double e) {
m_epsilon = e;
}
/**
* Get the epsilon threshold on the error for epsilon insensitive and Huber
* loss functions
*
* @return the value of epsilon to use
*/
public double getEpsilon() {
return m_epsilon;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for displaying in the
* explorer/experimenter gui
*/
public String lambdaTipText() {
return "The regularization constant. (default = 0.0001)";
}
/**
* Set the value of lambda to use
*
* @param lambda the value of lambda to use
*/
public void setLambda(double lambda) {
m_lambda = lambda;
}
/**
* Get the current value of lambda
*
* @return the current value of lambda
*/
public double getLambda() {
return m_lambda;
}
/**
* Set the learning rate.
*
* @param lr the learning rate to use.
*/
public void setLearningRate(double lr) {
m_learningRate = lr;
}
/**
* Get the learning rate.
*
* @return the learning rate
*/
public double getLearningRate() {
return m_learningRate;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for displaying in the
* explorer/experimenter gui
*/
public String learningRateTipText() {
return "The learning rate. If normalization is turned off "
+ "(as it is automatically for streaming data), then"
+ "the default learning rate will need to be reduced ("
+ "try 0.0001).";
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for displaying in the
* explorer/experimenter gui
*/
public String epochsTipText() {
return "The number of epochs to perform (batch learning). "
+ "The total number of iterations is epochs * num" + " instances.";
}
/**
* Set the number of epochs to use
*
* @param e the number of epochs to use
*/
public void setEpochs(int e) {
m_epochs = e;
}
/**
* Get current number of epochs
*
* @return the current number of epochs
*/
public int getEpochs() {
return m_epochs;
}
/**
* Turn normalization off/on.
*
* @param m true if normalization is to be disabled.
*/
public void setDontNormalize(boolean m) {
m_dontNormalize = m;
}
/**
* Get whether normalization has been turned off.
*
* @return true if normalization has been disabled.
*/
public boolean getDontNormalize() {
return m_dontNormalize;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for displaying in the
* explorer/experimenter gui
*/
public String dontNormalizeTipText() {
return "Turn normalization off";
}
/**
* Turn global replacement of missing values off/on. If turned off, then
* missing values are effectively ignored.
*
* @param m true if global replacement of missing values is to be turned off.
*/
public void setDontReplaceMissing(boolean m) {
m_dontReplaceMissing = m;
}
/**
* Get whether global replacement of missing values has been disabled.
*
* @return true if global replacement of missing values has been turned off
*/
public boolean getDontReplaceMissing() {
return m_dontReplaceMissing;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for displaying in the
* explorer/experimenter gui
*/
public String dontReplaceMissingTipText() {
return "Turn off global replacement of missing values";
}
/**
* Set the loss function to use.
*
* @param function the loss function to use.
*/
public void setLossFunction(SelectedTag function) {
if (function.getTags() == TAGS_SELECTION) {
m_loss = function.getSelectedTag().getID();
}
}
/**
* Get the current loss function.
*
* @return the current loss function.
*/
public SelectedTag getLossFunction() {
return new SelectedTag(m_loss, TAGS_SELECTION);
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for displaying in the
* explorer/experimenter gui
*/
public String lossFunctionTipText() {
return "The loss function to use. Hinge loss (SVM), "
+ "log loss (logistic regression) or " + "squared loss (regression).";
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
@Override
public Enumeration<Option> listOptions() {
Vector<Option> newVector = new Vector<Option>();
newVector.add(new Option("\tSet the loss function to minimize.\n\t0 = "
+ "hinge loss (SVM), 1 = log loss (logistic regression),\n\t"
+ "2 = squared loss (regression), 3 = epsilon insensitive loss (regression)," +
"\n\t4 = Huber loss (regression).\n\t(default = 0)", "F", 1, "-F"));
newVector
.add(new Option(
"\tThe learning rate. If normalization is\n"
+ "\tturned off (as it is automatically for streaming data), then the\n\t"
+ "default learning rate will need to be reduced "
+ "(try 0.0001).\n\t(default = 0.01).", "L", 1, "-L"));
newVector.add(new Option("\tThe lambda regularization constant "
+ "(default = 0.0001)", "R", 1, "-R <double>"));
newVector.add(new Option("\tThe number of epochs to perform ("
+ "batch learning only, default = 500)", "E", 1, "-E <integer>"));
newVector.add(new Option("\tThe epsilon threshold ("
+ "epsilon-insenstive and Huber loss only, default = 1e-3)", "C", 1,
"-C <double>"));
newVector.add(new Option("\tDon't normalize the data", "N", 0, "-N"));
newVector.add(new Option("\tDon't replace missing values", "M", 0, "-M"));
return newVector.elements();
}
/**
*
* Parses a given list of options.
* <p/>
*
<!-- options-start -->
* Valid options are:
* <p/>
*
* <pre>
* -F
 *  Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression),
 *  2 = squared loss (regression), 3 = epsilon-insensitive loss (regression),
 *  4 = Huber loss (regression).
 *  (default = 0)
* </pre>
*
* <pre>
* -L
* The learning rate. If normalization is
* turned off (as it is automatically for streaming data), then the
* default learning rate will need to be reduced (try 0.0001).
* (default = 0.01).
* </pre>
*
* <pre>
* -R <double>
* The lambda regularization constant (default = 0.0001)
* </pre>
*
* <pre>
* -E <integer>
* The number of epochs to perform (batch learning only, default = 500)
* </pre>
*
* <pre>
* -C <double>
 *  The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
* </pre>
*
* <pre>
* -N
* Don't normalize the data
* </pre>
*
* <pre>
* -M
* Don't replace missing values
* </pre>
*
<!-- options-end -->
*
* @param options the list of options as an array of strings
* @throws Exception if an option is not supported
*/
@Override
public void setOptions(String[] options) throws Exception {
reset();
super.setOptions(options);
String lossString = Utils.getOption('F', options);
if (lossString.length() != 0) {
setLossFunction(new SelectedTag(Integer.parseInt(lossString),
TAGS_SELECTION));
}
String lambdaString = Utils.getOption('R', options);
if (lambdaString.length() > 0) {
setLambda(Double.parseDouble(lambdaString));
}
String learningRateString = Utils.getOption('L', options);
if (learningRateString.length() > 0) {
setLearningRate(Double.parseDouble(learningRateString));
}
String epochsString = Utils.getOption("E", options);
if (epochsString.length() > 0) {
setEpochs(Integer.parseInt(epochsString));
}
String epsilonString = Utils.getOption("C", options);
if (epsilonString.length() > 0) {
setEpsilon(Double.parseDouble(epsilonString));
}
setDontNormalize(Utils.getFlag("N", options));
setDontReplaceMissing(Utils.getFlag('M', options));
}
/**
* Gets the current settings of the classifier.
*
* @return an array of strings suitable for passing to setOptions
*/
@Override
public String[] getOptions() {
ArrayList<String> options = new ArrayList<String>();
options.add("-F");
options.add("" + getLossFunction().getSelectedTag().getID());
options.add("-L");
options.add("" + getLearningRate());
options.add("-R");
options.add("" + getLambda());
options.add("-E");
options.add("" + getEpochs());
options.add("-C");
options.add("" + getEpsilon());
if (getDontNormalize()) {
options.add("-N");
}
if (getDontReplaceMissing()) {
options.add("-M");
}
    return options.toArray(new String[0]);
}
/**
* Returns a string describing classifier
*
* @return a description suitable for displaying in the explorer/experimenter
* gui
*/
public String globalInfo() {
return "Implements stochastic gradient descent for learning"
+ " various linear models (binary class SVM, binary class"
+ " logistic regression, squared loss, Huber loss and "
+ "epsilon-insensitive loss linear regression)."
+ " Globally replaces all missing values and transforms nominal"
+ " attributes into binary ones. It also normalizes all attributes,"
+ " so the coefficients in the output are based on the normalized"
+ " data.\n" + "For numeric class attributes, the squared, Huber or "
+ "epsilon-insensitve loss function must be used. Epsilon-insensitive "
+ "and Huber loss may require a much higher learning rate.";
}
/**
* Reset the classifier.
*/
public void reset() {
m_t = 1;
m_weights = null;
}
/**
* Method for building the classifier.
*
* @param data the set of training instances.
* @throws Exception if the classifier can't be built successfully.
*/
@Override
public void buildClassifier(Instances data) throws Exception {
reset();
// can classifier handle the data?
getCapabilities().testWithFail(data);
data = new Instances(data);
data.deleteWithMissingClass();
if (data.numInstances() > 0 && !m_dontReplaceMissing) {
m_replaceMissing = new ReplaceMissingValues();
m_replaceMissing.setInputFormat(data);
data = Filter.useFilter(data, m_replaceMissing);
}
// check for only numeric attributes
boolean onlyNumeric = true;
for (int i = 0; i < data.numAttributes(); i++) {
if (i != data.classIndex()) {
if (!data.attribute(i).isNumeric()) {
onlyNumeric = false;
break;
}
}
}
if (!onlyNumeric) {
if (data.numInstances() > 0) {
m_nominalToBinary = new weka.filters.supervised.attribute.NominalToBinary();
} else {
m_nominalToBinary = new weka.filters.unsupervised.attribute.NominalToBinary();
}
m_nominalToBinary.setInputFormat(data);
data = Filter.useFilter(data, m_nominalToBinary);
}
if (!m_dontNormalize && data.numInstances() > 0) {
m_normalize = new Normalize();
m_normalize.setInputFormat(data);
data = Filter.useFilter(data, m_normalize);
}
m_numInstances = data.numInstances();
m_weights = new double[data.numAttributes() + 1];
m_data = new Instances(data, 0);
if (data.numInstances() > 0) {
data.randomize(new Random(getSeed())); // randomize the data
train(data);
}
}
/** the hinge loss function. */
public static final int HINGE = 0;
/** the log loss function. */
public static final int LOGLOSS = 1;
/** the squared loss function. */
public static final int SQUAREDLOSS = 2;
/** The epsilon insensitive loss function */
public static final int EPSILON_INSENSITIVE = 3;
/** The Huber loss function */
public static final int HUBER = 4;
/** The current loss function to minimize */
protected int m_loss = HINGE;
/** Loss functions to choose from */
public static final Tag[] TAGS_SELECTION = {
new Tag(HINGE, "Hinge loss (SVM)"),
new Tag(LOGLOSS, "Log loss (logistic regression)"),
new Tag(SQUAREDLOSS, "Squared loss (regression)"),
new Tag(EPSILON_INSENSITIVE, "Epsilon-insensitive loss (SVM regression)"),
new Tag(HUBER, "Huber loss (robust regression)") };
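/**
 * Computes the loss-derivative factor used in the stochastic gradient update,
 * as a function of z: the margin y * f(x) for hinge and log loss, or the
 * residual y - f(x) for the regression losses.
 *
 * @param z the margin (classification) or residual (regression)
 * @return the loss-dependent factor that multiplies the learning rate in the
 *         weight update
 */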
protected double dloss(double z) {
if (m_loss == HINGE) {
return (z < 1) ? 1 : 0;
}
if (m_loss == LOGLOSS) {
// log loss
if (z < 0) {
return 1.0 / (Math.exp(z) + 1.0);
} else {
double t = Math.exp(-z);
return t / (t + 1);
}
}
if (m_loss == EPSILON_INSENSITIVE) {
if (z > m_epsilon) {
return 1.0;
}
if (-z > m_epsilon) {
return -1.0;
}
return 0;
}
if (m_loss == HUBER) {
if (Math.abs(z) <= m_epsilon) {
return z;
} else if (z > 0.0) {
return m_epsilon;
} else {
return -m_epsilon;
}
}
// squared loss
return z;
}
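/**
 * Performs m_epochs passes over the training data, applying the stochastic
 * gradient update to each instance in turn.
 *
 * @param data the (already filtered) training instances
 * @throws Exception if an instance cannot be processed
 */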
private void train(Instances data) throws Exception {
for (int e = 0; e < m_epochs; e++) {
for (int i = 0; i < data.numInstances(); i++) {
updateClassifier(data.instance(i), false);
}
}
}
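/**
 * Computes the dot product of a (possibly sparse) instance and the weight
 * vector, skipping the class attribute and missing values. The bias term
 * stored in the last element of the weight vector is not included.
 *
 * @param inst1 the instance
 * @param weights the weight vector (bias in the last element)
 * @param classIndex the index of the class attribute to skip
 * @return the dot product
 */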
protected static double dotProd(Instance inst1, double[] weights,
int classIndex) {
double result = 0;
int n1 = inst1.numValues();
int n2 = weights.length - 1;
for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
int ind1 = inst1.index(p1);
int ind2 = p2;
if (ind1 == ind2) {
if (ind1 != classIndex && !inst1.isMissingSparse(p1)) {
result += inst1.valueSparse(p1) * weights[p2];
}
p1++;
p2++;
} else if (ind1 > ind2) {
p2++;
} else {
p1++;
}
}
return (result);
}
/**
* Updates the classifier with the given instance.
*
* @param instance the new training instance to include in the model
* @param filter true if the instance should pass through any of the filters
* set up in buildClassifier(). When batch training buildClassifier()
* already batch filters all training instances so don't need to
* filter them again here.
* @exception Exception if the instance could not be incorporated in the
* model.
*/
protected void updateClassifier(Instance instance, boolean filter)
throws Exception {
if (!instance.classIsMissing()) {
if (filter) {
if (m_replaceMissing != null) {
m_replaceMissing.input(instance);
instance = m_replaceMissing.output();
}
if (m_nominalToBinary != null) {
m_nominalToBinary.input(instance);
instance = m_nominalToBinary.output();
}
if (m_normalize != null) {
m_normalize.input(instance);
instance = m_normalize.output();
}
}
double wx = dotProd(instance, m_weights, instance.classIndex());
double y;
double z;
if (instance.classAttribute().isNominal()) {
y = (instance.classValue() == 0) ? -1 : 1;
z = y * (wx + m_weights[m_weights.length - 1]);
} else {
y = instance.classValue();
z = y - (wx + m_weights[m_weights.length - 1]);
y = 1;
}
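      // Stochastic gradient step on the regularized risk: first shrink all
      // weights by the weight-decay multiplier (the gradient of the L2
      // penalty), then add learningRate * y * dloss(z) * x when the loss
      // derivative is non-zero.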
// Compute multiplier for weight decay
double multiplier = 1.0;
if (m_numInstances == 0) {
multiplier = 1.0 - (m_learningRate * m_lambda) / m_t;
} else {
multiplier = 1.0 - (m_learningRate * m_lambda) / m_numInstances;
}
for (int i = 0; i < m_weights.length - 1; i++) {
m_weights[i] *= multiplier;
}
// Only need to do the following if the loss is non-zero
if (m_loss == SQUAREDLOSS || m_loss == LOGLOSS || m_loss == HUBER
|| (m_loss == HINGE && (z < 1))
|| (m_loss == EPSILON_INSENSITIVE && Math.abs(z) > m_epsilon)) {
// Compute Factor for updates
double factor = m_learningRate * y * dloss(z);
// Update coefficients for attributes
int n1 = instance.numValues();
for (int p1 = 0; p1 < n1; p1++) {
int indS = instance.index(p1);
if (indS != instance.classIndex() && !instance.isMissingSparse(p1)) {
m_weights[indS] += factor * instance.valueSparse(p1);
}
}
// update the bias
m_weights[m_weights.length - 1] += factor;
}
m_t++;
}
}
/**
* Updates the classifier with the given instance.
*
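 * <p/>
 * A sketch of incremental use, assuming an Instances header named "structure"
 * (class index set) and instances arriving one at a time (names hypothetical):
 *
 * <pre>
 * SGD sgd = new SGD();
 * sgd.buildClassifier(structure); // an empty training set initializes the model
 * sgd.updateClassifier(inst); // filters the instance and takes one gradient step
 * </pre>
 *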
* @param instance the new training instance to include in the model
* @exception Exception if the instance could not be incorporated in the
* model.
*/
@Override
public void updateClassifier(Instance instance) throws Exception {
updateClassifier(instance, true);
}
/**
* Computes the distribution for a given instance
*
 * @param inst the instance for which the distribution is computed
* @return the distribution
* @throws Exception if the distribution can't be computed successfully
*/
@Override
public double[] distributionForInstance(Instance inst) throws Exception {
double[] result = (inst.classAttribute().isNominal()) ? new double[2]
: new double[1];
if (m_replaceMissing != null) {
m_replaceMissing.input(inst);
inst = m_replaceMissing.output();
}
if (m_nominalToBinary != null) {
m_nominalToBinary.input(inst);
inst = m_nominalToBinary.output();
}
if (m_normalize != null) {
m_normalize.input(inst);
inst = m_normalize.output();
}
    double wx = dotProd(inst, m_weights, inst.classIndex());
double z = (wx + m_weights[m_weights.length - 1]);
if (inst.classAttribute().isNumeric()) {
result[0] = z;
return result;
}
if (z <= 0) {
if (m_loss == LOGLOSS) {
result[0] = 1.0 / (1.0 + Math.exp(z));
result[1] = 1.0 - result[0];
} else {
result[0] = 1;
}
} else {
if (m_loss == LOGLOSS) {
result[1] = 1.0 / (1.0 + Math.exp(-z));
result[0] = 1.0 - result[1];
} else {
result[1] = 1;
}
}
return result;
}
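/**
 * Returns the current weight vector. The bias term is stored in the last
 * element; the entry at the class index is unused.
 *
 * @return the weights (and bias) of the model
 */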
public double[] getWeights() {
return m_weights;
}
/**
* Prints out the classifier.
*
* @return a description of the classifier as a string
*/
@Override
public String toString() {
if (m_weights == null) {
return "SGD: No model built yet.\n";
}
StringBuffer buff = new StringBuffer();
buff.append("Loss function: ");
    if (m_loss == HINGE) {
      buff.append("Hinge loss (SVM)\n\n");
    } else if (m_loss == LOGLOSS) {
      buff.append("Log loss (logistic regression)\n\n");
    } else if (m_loss == EPSILON_INSENSITIVE) {
      buff.append("Epsilon-insensitive loss (SVM regression)\n\n");
    } else if (m_loss == HUBER) {
      buff.append("Huber loss (robust regression)\n\n");
    } else {
      buff.append("Squared loss (linear regression)\n\n");
    }
buff.append(m_data.classAttribute().name() + " = \n\n");
int printed = 0;
for (int i = 0; i < m_weights.length - 1; i++) {
if (i != m_data.classIndex()) {
if (printed > 0) {
buff.append(" + ");
} else {
buff.append(" ");
}
buff.append(Utils.doubleToString(m_weights[i], 12, 4) + " "
+ ((m_normalize != null) ? "(normalized) " : "")
+ m_data.attribute(i).name() + "\n");
printed++;
}
}
if (m_weights[m_weights.length - 1] > 0) {
buff.append(" + "
+ Utils.doubleToString(m_weights[m_weights.length - 1], 12, 4));
} else {
buff.append(" - "
+ Utils.doubleToString(-m_weights[m_weights.length - 1], 12, 4));
}
return buff.toString();
}
/**
* Returns the revision string.
*
* @return the revision
*/
@Override
public String getRevision() {
return RevisionUtils.extract("$Revision: 9785 $");
}
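/** The number of models that have been aggregated with this one */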
protected int m_numModels = 0;
/**
* Aggregate an object with this one
*
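 * <p/>
 * Aggregation sums the weight vectors; a later call to finalizeAggregation()
 * divides by the number of models to produce an average. A sketch, assuming
 * two SGD models m1 and m2 trained on identically structured data (names
 * hypothetical):
 *
 * <pre>
 * m1.aggregate(m2);
 * m1.finalizeAggregation(); // m1 now holds the averaged weights
 * </pre>
 *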
* @param toAggregate the object to aggregate
* @return the result of aggregation
* @throws Exception if the supplied object can't be aggregated for some
* reason
*/
@Override
public SGD aggregate(SGD toAggregate) throws Exception {
if (m_weights == null) {
throw new Exception("No model built yet, can't aggregate");
}
if (!m_data.equalHeaders(toAggregate.m_data)) {
throw new Exception("Can't aggregate - data headers dont match: "
+ m_data.equalHeadersMsg(toAggregate.m_data));
}
if (m_weights.length != toAggregate.getWeights().length) {
throw new Exception(
"Can't aggregate - SDG to aggregate has weight vector "
+ "that differs in length from ours.");
}
for (int i = 0; i < m_weights.length; i++) {
m_weights[i] += toAggregate.getWeights()[i];
}
m_numModels++;
return this;
}
/**
* Call to complete the aggregation process. Allows implementers to do any
* final processing based on how many objects were aggregated.
*
* @throws Exception if the aggregation can't be finalized for some reason
*/
@Override
public void finalizeAggregation() throws Exception {
if (m_numModels == 0) {
throw new Exception("Unable to finalize aggregation - " +
"haven't seen any models to aggregate");
}
for (int i = 0; i < m_weights.length; i++) {
m_weights[i] /= (m_numModels + 1); // plus one for us
}
// aggregation complete
m_numModels = 0;
}
/**
* Main method for testing this class.
*/
public static void main(String[] args) {
runClassifier(new SGD(), args);
}
}