package edu.stanford.nlp.classify; import java.util.*; import edu.stanford.nlp.ling.*; import edu.stanford.nlp.stats.*; import edu.stanford.nlp.util.CollectionValuedMap; /** * This constructs trained <code>KNNClassifier</code> objects, given * sets of RVFDatums, or Counters (dimensions are identified by the keys). */ public class KNNClassifierFactory<K, V> { private int k = 0; private boolean weightedVotes = false; private boolean l2NormalizeVectors = false; /** * Creates a new factory that generates K-NN classifiers with the given k-value, and * if the votes are weighted by their similarity score, or unit value. */ public KNNClassifierFactory(int k, boolean weightedVotes, boolean l2NormalizeVectors) { this.k = k; this.weightedVotes = weightedVotes; this.l2NormalizeVectors = l2NormalizeVectors; } /** * Given a set of labeled RVFDatums, treats each as an instance vector of that * label and adds it to the examples used for classification. * * NOTE: l2NormalizeVectors is NOT applied here. */ public KNNClassifier<K,V> train(Collection<RVFDatum<K, V>> instances) { KNNClassifier<K, V> classifier = new KNNClassifier<>(k, weightedVotes, l2NormalizeVectors); classifier.addInstances(instances); return classifier; } /** * Given a set of vectors, and a mapping from each vector to its class label, * generates the sets of instances used to perform classifications and returns * the corresponding K-NN classifier. * * NOTE: if l2NormalizeVectors is T, creates a copy and applies L2Normalize to it. */ public KNNClassifier<K,V> train(Collection<Counter<V>> vectors, Map<V, K> labelMap) { KNNClassifier<K, V> classifier = new KNNClassifier<>(k, weightedVotes, l2NormalizeVectors); Collection<RVFDatum<K, V>> instances = new ArrayList<>(); for (Counter<V> vector : vectors) { K label = labelMap.get(vector); RVFDatum<K, V> datum; if (l2NormalizeVectors) { datum = new RVFDatum<>(Counters.L2Normalize(new ClassicCounter<>(vector)), label); } else { datum = new RVFDatum<>(vector, label); } instances.add(datum); } classifier.addInstances(instances); return classifier; } /** * Given a CollectionValued Map of vectors, treats outer key as label for each * set of inner vectors. * NOTE: if l2NormalizeVectors is T, creates a copy of each vector and applies * l2Normalize to it. */ public KNNClassifier<K,V> train(CollectionValuedMap<K, Counter<V>> vecBag) { KNNClassifier<K, V> classifier = new KNNClassifier<>(k, weightedVotes, l2NormalizeVectors); Collection<RVFDatum<K, V>> instances = new ArrayList<>(); for (K label : vecBag.keySet()) { RVFDatum<K, V> datum; for (Counter<V> vector : vecBag.get(label)) { if (l2NormalizeVectors) { datum = new RVFDatum<>(Counters.L2Normalize(new ClassicCounter<>(vector)), label); } else { datum = new RVFDatum<>(vector, label); } instances.add(datum); } } classifier.addInstances(instances); return classifier; } }