package cc.mallet.cluster.neighbor_evaluator;
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

import cc.mallet.classify.Classifier;
/**
 * A {@link NeighborEvaluator} that is backed by a {@link Classifier}. The score
 * for a {@link Neighbor} is the Classifier's predicted value for the label
 * corresponding to <code>scoringLabel</code>.
 *
 * @author "Aron Culotta" <culotta@degas.cs.umass.edu>
 * @version 1.0
 * @since 1.0
 * @see NeighborEvaluator
 */
public class ClassifyingNeighborEvaluator implements NeighborEvaluator, Serializable {

  /**
   * The Classifier used to assign a score to each {@link Neighbor}.
   */
  Classifier classifier;

  /**
   * The label corresponding to a positive instance (e.g. "YES").
   */
  String scoringLabel;

  /**
   * Creates an evaluator that scores neighbors with the given classifier.
   *
   * @param classifier The Classifier used to assign a score to each {@link Neighbor}.
   * @param scoringLabel The label corresponding to a positive instance (e.g. "YES").
   */
  public ClassifyingNeighborEvaluator (Classifier classifier,
                                       String scoringLabel) {
    this.classifier = classifier;
    this.scoringLabel = scoringLabel;
  }

  /**
   * Returns the classifier backing this evaluator.
   *
   * @return The classifier.
   */
  public Classifier getClassifier () { return classifier; }

  /**
   * Scores a single neighbor by classifying it and reading the value that the
   * resulting label vector assigns to {@link #scoringLabel}.
   *
   * @param neighbor the neighbor to score
   * @return the classifier's value for {@code scoringLabel}
   */
  public double evaluate (Neighbor neighbor) {
    return classifier.classify(neighbor).getLabelVector().value(scoringLabel);
  }

  /**
   * Scores each neighbor independently via {@link #evaluate(Neighbor)}.
   *
   * @param neighbors the neighbors to score
   * @return an array whose i-th entry is the score of {@code neighbors[i]}
   */
  public double[] evaluate (Neighbor[] neighbors) {
    double[] scores = new double[neighbors.length];
    for (int i = 0; i < neighbors.length; i++) {
      scores[i] = evaluate(neighbors[i]);
    }
    return scores;
  }

  /**
   * No-op: this evaluator keeps no per-run state beyond its classifier and label.
   */
  public void reset () {
  }

  /**
   * Describes this evaluator's class, its classifier's class, and the scoring label.
   * A missing classifier is rendered as {@code null} instead of throwing.
   */
  @Override
  public String toString () {
    String classifierName = (classifier == null) ? "null" : classifier.getClass().getName();
    return "class=" + this.getClass().getName() +
        " classifier=" + classifierName +
        " scoringLabel=" + scoringLabel;
  }

  // SERIALIZATION

  private static final long serialVersionUID = 1;

  private static final int CURRENT_SERIAL_VERSION = 1;

  private void writeObject (ObjectOutputStream out) throws IOException {
    out.defaultWriteObject ();
    out.writeInt (CURRENT_SERIAL_VERSION);
  }

  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
    in.defaultReadObject ();
    int version = in.readInt ();
    // Fail loudly on corrupt data or a format this class does not know how to read,
    // rather than silently accepting it.
    if (version < 1 || version > CURRENT_SERIAL_VERSION) {
      throw new InvalidObjectException(
          "Unsupported serial version " + version
              + " (supported: 1.." + CURRENT_SERIAL_VERSION + ")");
    }
  }
}