package edu.stanford.nlp.classify;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.HashIndex;
import java.lang.ref.Reference;
import java.util.Collection;
import java.util.List;
/**
* Shared methods for training a {@link LinearClassifier}.
* Inheriting classes need to implement the
* <code>trainWeights</code> method.
*
* @author Dan Klein
*
* @author Sarah Spikes (sdspikes@cs.stanford.edu) (Templatization)
*
* @param <L> The type of the labels in the Dataset and Datum
* @param <F> The type of the features in the Dataset and Datum
*/
public abstract class AbstractLinearClassifierFactory<L, F> implements ClassifierFactory<L, F, Classifier<L, F>> {

  private static final long serialVersionUID = 1L;

  /**
   * Label index used for training. Starts out empty and is replaced by the
   * dataset's label index on every call to
   * {@link #trainClassifier(GeneralDataset)}.
   */
  Index<L> labelIndex = new HashIndex<>();

  /**
   * Feature index used for training. Starts out empty and is replaced by the
   * dataset's feature index on every call to
   * {@link #trainClassifier(GeneralDataset)}.
   */
  Index<F> featureIndex = new HashIndex<>();

  public AbstractLinearClassifierFactory() {
  }

  /**
   * @return the number of entries in the current feature index
   */
  int numFeatures() {
    return featureIndex.size();
  }

  /**
   * @return the number of entries in the current label index
   */
  int numClasses() {
    return labelIndex.size();
  }

  /**
   * Computes the weights of the linear classifier for the given dataset.
   * This is invoked by {@link #trainClassifier(GeneralDataset)} after
   * {@code labelIndex} and {@code featureIndex} have been set from that
   * dataset, so implementations may rely on {@link #numFeatures()} and
   * {@link #numClasses()} describing it.
   *
   * @param dataset the data to train on
   * @return the weight matrix; it is handed unchanged to the
   *         {@link LinearClassifier} constructor together with the feature
   *         and label indices, so its layout must be the one that
   *         constructor expects
   */
  protected abstract double[][] trainWeights(GeneralDataset<L, F> dataset) ;

  /**
   * Takes a {@link Collection} of {@link Datum} objects and gives you back a
   * {@link Classifier} trained on it. The examples are first copied into a
   * new {@link Dataset}, which is then passed to
   * {@link #trainClassifier(GeneralDataset)}.
   *
   * @param examples {@link Collection} of {@link Datum} objects to train the
   *                 classifier on
   * @return A {@link Classifier} trained on it.
   */
  public LinearClassifier<L, F> trainClassifier(Collection<Datum<L, F>> examples) {
    Dataset<L, F> dataset = new Dataset<>();
    dataset.addAll(examples);
    return trainClassifier(dataset);
  }

  /**
   * Takes a {@link Reference} to a {@link Collection} of {@link Datum}
   * objects and gives you back a {@link Classifier} trained on them.
   *
   * @param ref {@link Reference} to a {@link Collection} of {@link
   *            Datum} objects to train the classifier on
   * @return A Classifier trained on a collection of Datum
   * @throws NullPointerException if {@code ref} is null, or if its referent
   *         is null (for example because the reference has been cleared)
   */
  public LinearClassifier<L, F> trainClassifier(Reference<? extends Collection<Datum<L, F>>> ref) {
    Collection<Datum<L, F>> examples = ref.get();
    // Reference.get() yields null once the referent has been cleared; fail
    // here with a clear message instead of letting the null travel into the
    // dataset-building code.
    if (examples == null) {
      throw new NullPointerException(
          "Cannot train classifier: the referenced collection of examples is null"
              + " (the Reference may have been cleared)");
    }
    return trainClassifier(examples);
  }

  /**
   * Trains a {@link Classifier} on a {@link Dataset}.
   * <p>
   * Note that this replaces this factory's label and feature indices with
   * those of {@code data} before the weights are trained, and the returned
   * classifier is built on those same index objects.
   *
   * @param data the dataset to train on
   * @return A {@link Classifier} trained on the data.
   */
  public LinearClassifier<L, F> trainClassifier(GeneralDataset<L, F> data) {
    labelIndex = data.labelIndex();
    featureIndex = data.featureIndex();
    double[][] weights = trainWeights(data);
    return new LinearClassifier<>(weights, featureIndex, labelIndex);
  }

}