/* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.discriminative;
import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Base class for training linear discriminative classifiers. Subclasses
 * supply the concrete weight-update strategy.
 *
 * As this is just the reference implementation we assume that the dataset fits
 * into main memory - this should be the first thing to change when switching to
 * Hadoop.
 */
public abstract class LinearTrainer {

  private static final Logger log = LoggerFactory.getLogger(LinearTrainer.class);

  /** Maximum number of full passes over the dataset before training is aborted. */
  private static final int MAX_ITERATIONS = 1000;

  /** The model to train. */
  private final LinearModel model;

  /**
   * Initializes the trainer. All weights are represented through a dense
   * vector, each entry set to the same initial value.
   *
   * @param dimension
   *          number of expected features.
   * @param threshold
   *          threshold to use for classification.
   * @param init
   *          initial value of every entry of the weight vector.
   * @param initBias
   *          initial classification bias.
   */
  protected LinearTrainer(int dimension, double threshold,
                          double init, double initBias) {
    DenseVector initialWeights = new DenseVector(dimension);
    initialWeights.assign(init);
    this.model = new LinearModel(initialWeights, initBias, threshold);
  }

  /**
   * Runs through all data points in the training set and updates the weight
   * vector whenever a classification error occurs.
   *
   * Can be called multiple times.
   *
   * @param labelset
   *          the set of labels, one for each data point; a value &lt;= 0 is
   *          treated as the negative class, &gt; 0 as the positive class.
   * @param dataset
   *          the dataset to train on. Each column is treated as one data point.
   * @throws TrainingException
   *           if no separating hyperplane was found within
   *           {@value #MAX_ITERATIONS} passes over the dataset.
   * @throws CardinalityException
   *           if the cardinalities of dataset columns and labelset do not match.
   */
  public void train(Vector labelset, Matrix dataset) throws TrainingException {
    if (labelset.size() != dataset.columnSize()) {
      throw new CardinalityException(labelset.size(), dataset.columnSize());
    }
    // Loop invariant: hoisted out of the convergence loop.
    int columnCount = dataset.columnSize();
    boolean converged = false;
    int iteration = 0;
    while (!converged) {
      if (iteration > MAX_ITERATIONS) {
        throw new TrainingException("Too many iterations needed to find hyperplane.");
      }
      converged = true;
      for (int i = 0; i < columnCount; i++) {
        Vector dataPoint = dataset.viewColumn(i);
        log.debug("Training point: {}", dataPoint);
        // Guard the model while it is evaluated and possibly corrected, so a
        // concurrent reader via getModel() never observes a half-applied update.
        synchronized (this.model) {
          boolean prediction = model.classify(dataPoint);
          double label = labelset.get(i);
          // Update only on misclassification: negative label predicted
          // positive, or positive label predicted negative.
          if (label <= 0 && prediction || label > 0 && !prediction) {
            log.debug("updating");
            converged = false;
            update(label, dataPoint, this.model);
          }
        }
      }
      iteration++;
    }
  }

  /**
   * Retrieves the trained model if called after train, otherwise the raw model.
   */
  public LinearModel getModel() {
    return this.model;
  }

  /**
   * Implement this method to match your training strategy.
   *
   * @param label
   *          the target label of the wrongly classified data point.
   * @param dataPoint
   *          the data point that was classified incorrectly.
   * @param model
   *          the model to update.
   */
  protected abstract void update(double label, Vector dataPoint, LinearModel model);
}