package edu.stanford.nlp.classify; import java.util.*; import edu.stanford.nlp.ling.Datum; import edu.stanford.nlp.ling.RVFDatum; import edu.stanford.nlp.stats.*; import edu.stanford.nlp.util.CollectionValuedMap; import edu.stanford.nlp.util.Generics; /** * A simple k-NN classifier, with the options of using unit votes, or weighted votes (by * similarity value). Use the <code>KNNClassifierFactory</code> class to train and instantiate * a new classifier. * * NOTE: partially generified, waiting for final generification of classifiers package. * @author Eric Yeh * * @param <K> Class label type * @param <V> Feature vector dimension type */ public class KNNClassifier<K,V> implements Classifier<K, V> { /** * */ private static final long serialVersionUID = 7115357548209007944L; private boolean weightedVotes = false; // whether this is a weighted vote (by sim), or not private CollectionValuedMap<K, Counter<V>> instances = new CollectionValuedMap<>(); private Map<Counter<V>, K> classLookup = Generics.newHashMap(); private boolean l2Normalize = false; int k = 0; public Collection<K> labels() { return classLookup.values(); } protected KNNClassifier(int k, boolean weightedVotes, boolean l2Normalize) { this.k = k; this.weightedVotes = weightedVotes; this.l2Normalize = l2Normalize; } protected void addInstances(Collection<RVFDatum<K, V>> datums) { for (RVFDatum<K, V> datum : datums) { K label = datum.label(); Counter<V> vec = datum.asFeaturesCounter(); instances.add(label, vec); classLookup.put(vec, label); } } /** * NOTE: currently does not support standard Datums, only RVFDatums. */ public K classOf(Datum<K, V> example) { if (example instanceof RVFDatum<?,?>) { ClassicCounter<K> scores = scoresOf(example); return Counters.toSortedList(scores).get(0); } else { return null; } } /** * Given an instance to classify, scores and returns * score by class. * * NOTE: supports only RVFDatums */ public ClassicCounter<K> scoresOf(Datum<K, V> datum) { if (datum instanceof RVFDatum<?,?>) { RVFDatum<K, V> vec = (RVFDatum<K, V>) datum; if (l2Normalize) { ClassicCounter<V> featVec = new ClassicCounter<>(vec.asFeaturesCounter()); Counters.normalize(featVec); vec = new RVFDatum<>(featVec); } ClassicCounter<Counter<V>> scores = new ClassicCounter<>(); for (Counter<V> instance : instances.allValues()) { scores.setCount(instance, Counters.cosine(vec.asFeaturesCounter(), instance)); // set entry, for given instance and score } List<Counter<V>> sorted = Counters.toSortedList(scores); ClassicCounter<K> classScores = new ClassicCounter<>(); for (int i=0;i<k && i<sorted.size(); i++) { K label = classLookup.get(sorted.get(i)); double count= 1.0; if (weightedVotes) { count = scores.getCount(sorted.get(i)); } classScores.incrementCount(label, count); } return classScores; } else { return null; } } // Quick little sanity check public static void main(String[] args) { Collection<RVFDatum<String, String>> trainingInstances = new ArrayList<>(); { ClassicCounter<String> f1 = new ClassicCounter<>(); f1.setCount("humidity", 5.0); f1.setCount("temperature", 35.0); trainingInstances.add(new RVFDatum<>(f1, "rain")); } { ClassicCounter<String> f1 = new ClassicCounter<>(); f1.setCount("humidity", 4.0); f1.setCount("temperature", 32.0); trainingInstances.add(new RVFDatum<>(f1, "rain")); } { ClassicCounter<String> f1 = new ClassicCounter<>(); f1.setCount("humidity", 6.0); f1.setCount("temperature", 30.0); trainingInstances.add(new RVFDatum<>(f1, "rain")); } { ClassicCounter<String> f1 = new ClassicCounter<>(); f1.setCount("humidity", 2.0); f1.setCount("temperature", 33.0); trainingInstances.add(new RVFDatum<>(f1, "dry")); } { ClassicCounter<String> f1 = new ClassicCounter<>(); f1.setCount("humidity", 1.0); f1.setCount("temperature", 34.0); trainingInstances.add(new RVFDatum<>(f1, "dry")); } KNNClassifier<String, String> classifier = new KNNClassifierFactory<String, String>(3, false, true).train(trainingInstances); { ClassicCounter<String> f1 = new ClassicCounter<>(); f1.setCount("humidity", 2.0); f1.setCount("temperature", 33.0); RVFDatum<String, String> testVec = new RVFDatum<>(f1); System.out.println(classifier.scoresOf(testVec)); System.out.println(classifier.classOf(testVec)); } } }