package edu.stanford.nlp.classify;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.ArrayMap;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import java.util.Collection;
import java.util.Map;
import edu.stanford.nlp.util.logging.Redwood;
/**
 * One vs All multiclass classifier.
 *
 * <p>Holds one binary classifier per label. Each binary classifier works over
 * the two String labels {@code "+1"} (the datum has the label in question) and
 * {@code "-1"} (anything else). The multiclass score of a label is the score
 * its binary classifier assigns to {@code "+1"}.
 *
 * @param <L> type of the multiclass labels
 * @param <F> type of the features
 *
 * @author Angel Chang
 */
public class OneVsAllClassifier<L,F> implements Classifier<L,F> {

  private static final long serialVersionUID = -743792054415242776L;

  private static final Redwood.RedwoodChannels logger = Redwood.channels(OneVsAllClassifier.class);

  /** Label the binary classifiers use for "is the target label". */
  private static final String POS_LABEL = "+1";
  /** Label the binary classifiers use for "is any other label". */
  private static final String NEG_LABEL = "-1";

  /** Label index shared by every binary dataset created in {@code train}. */
  private static final Index<String> binaryIndex;
  /** Position of {@link #POS_LABEL} in {@link #binaryIndex}; only reported in the training log. */
  private static final int posIndex;

  static {
    binaryIndex = new HashIndex<>();
    binaryIndex.add(POS_LABEL);
    binaryIndex.add(NEG_LABEL);
    posIndex = binaryIndex.indexOf(POS_LABEL);
  }

  private final Index<F> featureIndex;
  private final Index<L> labelIndex;
  /** Binary classifier for each label; labels that were never trained/registered have no entry. */
  private final Map<L, Classifier<String,F>> binaryClassifiers;
  /** Returned by {@link #classOf} when no label could be scored; may be {@code null}. */
  private final L defaultLabel;

  /**
   * Creates a classifier with no binary classifiers and no default label.
   * Binary classifiers are expected to be registered afterwards through
   * {@link #addBinaryClassifier}.
   *
   * @param featureIndex index of the features
   * @param labelIndex index of the multiclass labels
   */
  public OneVsAllClassifier(Index<F> featureIndex, Index<L> labelIndex) {
    this(featureIndex, labelIndex, Generics.newHashMap(), null);
  }

  /**
   * Creates a classifier from already trained binary classifiers, with no
   * default label.
   *
   * @param featureIndex index of the features
   * @param labelIndex index of the multiclass labels
   * @param binaryClassifiers binary classifier for each label (stored, not copied)
   */
  public OneVsAllClassifier(Index<F> featureIndex, Index<L> labelIndex, Map<L, Classifier<String, F>> binaryClassifiers) {
    this(featureIndex, labelIndex, binaryClassifiers, null);
  }

  /**
   * Creates a classifier from already trained binary classifiers.
   *
   * @param featureIndex index of the features
   * @param labelIndex index of the multiclass labels
   * @param binaryClassifiers binary classifier for each label (stored, not copied)
   * @param defaultLabel label returned by {@link #classOf} when no label can be scored
   */
  public OneVsAllClassifier(Index<F> featureIndex, Index<L> labelIndex, Map<L, Classifier<String, F>> binaryClassifiers, L defaultLabel) {
    this.featureIndex = featureIndex;
    this.labelIndex = labelIndex;
    this.binaryClassifiers = binaryClassifiers;
    this.defaultLabel = defaultLabel;
  }

  /**
   * Registers (or replaces) the binary classifier used for a label.
   *
   * @param label multiclass label the classifier decides on
   * @param classifier binary classifier over {@code "+1"} / {@code "-1"}
   */
  public void addBinaryClassifier(L label, Classifier<String,F> classifier) {
    binaryClassifiers.put(label, classifier);
  }

  /**
   * Looks up the binary classifier registered for a label.
   *
   * @param label multiclass label
   * @return the registered classifier, or whatever the backing map returns
   *         for a missing key (normally {@code null})
   */
  protected Classifier<String,F> getBinaryClassifier(L label) {
    return binaryClassifiers.get(label);
  }

  /**
   * Picks a label for the example by applying {@code Counters.argmax} to the
   * scores from {@link #scoresOf}.
   *
   * @param example datum to classify
   * @return the selected label, or the default label when no label could be
   *         scored (no scores at all, or an empty score counter)
   */
  @Override
  public L classOf(Datum<L, F> example) {
    Counter<L> scores = scoresOf(example);
    // An empty counter means no label had a binary classifier; that is the
    // situation the default label exists for.
    if (scores == null || scores.size() == 0) {
      return defaultLabel;
    }
    return Counters.argmax(scores);
  }

  /**
   * Scores every label of the label index that has a binary classifier.
   * The score of a label is the {@code "+1"} score its binary classifier
   * gives to the example.
   *
   * <p>Labels without a binary classifier are left out of the result. This
   * happens, for instance, when the classifier was trained on a subset of the
   * dataset's labels.
   *
   * @param example datum to score
   * @return counter from label to score; never {@code null}, possibly empty
   */
  @Override
  public Counter<L> scoresOf(Datum<L, F> example) {
    Counter<L> scores = new ClassicCounter<>();
    for (L label : labelIndex) {
      Classifier<String,F> binaryClassifier = getBinaryClassifier(label);
      if (binaryClassifier == null) {
        // Nothing was trained/registered for this label: skip it instead of
        // failing with a NullPointerException.
        continue;
      }
      // Re-label the datum from this label's point of view: "+1" if it
      // carries this label, "-1" otherwise.
      Map<L,String> posLabelMap = new ArrayMap<>();
      posLabelMap.put(label, POS_LABEL);
      Datum<String,F> binDatum = GeneralDataset.mapDatum(example, posLabelMap, NEG_LABEL);
      Counter<String> binScores = binaryClassifier.scoresOf(binDatum);
      scores.setCount(label, binScores.getCount(POS_LABEL));
    }
    return scores;
  }

  /**
   * @return the labels of the label index (including any label that has no
   *         binary classifier)
   */
  @Override
  public Collection<L> labels() {
    return labelIndex.objectsList();
  }

  /**
   * Trains one binary classifier for every label of the dataset.
   *
   * @param classifierFactory factory used to train each binary classifier
   * @param dataset multiclass training data
   * @return the trained one-vs-all classifier
   */
  public static <L,F> OneVsAllClassifier<L,F> train(ClassifierFactory<String,F, Classifier<String,F>> classifierFactory,
                                                    GeneralDataset<L, F> dataset) {
    Index<L> labelIndex = dataset.labelIndex();
    return train(classifierFactory, dataset, labelIndex.objectsList());
  }

  /**
   * Trains one binary classifier for each of the given labels.
   *
   * <p>The returned classifier uses the dataset's full label index, but only
   * the labels in {@code trainLabels} get a binary classifier.
   *
   * @param classifierFactory factory used to train each binary classifier
   * @param dataset multiclass training data
   * @param trainLabels labels to train a binary classifier for
   * @return the trained one-vs-all classifier
   */
  public static <L,F> OneVsAllClassifier<L,F> train(ClassifierFactory<String,F, Classifier<String,F>> classifierFactory,
                                                    GeneralDataset<L, F> dataset, Collection<L> trainLabels) {
    Index<L> labelIndex = dataset.labelIndex();
    Index<F> featureIndex = dataset.featureIndex();
    Map<L, Classifier<String, F>> classifiers = Generics.newHashMap();
    for (L label : trainLabels) {
      int i = labelIndex.indexOf(label);
      logger.info("Training " + label + " = " + i + ", posIndex = " + posIndex);
      // Create training data for training this classifier:
      // this label maps to "+1", every other label to "-1".
      Map<L,String> posLabelMap = new ArrayMap<>();
      posLabelMap.put(label, POS_LABEL);
      GeneralDataset<String,F> binaryDataset = dataset.mapDataset(dataset, binaryIndex, posLabelMap, NEG_LABEL);
      Classifier<String,F> binaryClassifier = classifierFactory.trainClassifier(binaryDataset);
      classifiers.put(label, binaryClassifier);
    }
    return new OneVsAllClassifier<>(featureIndex, labelIndex, classifiers);
  }

}