package func;
import dist.*;
import dist.Distribution;
import dist.DiscreteDistribution;
import func.inst.KDTree;
import shared.*;
import shared.DataSet;
import shared.DataSetDescription;
import shared.Instance;
/**
 * A k-nearest-neighbor classifier.  The training examples are stored in a
 * KD-tree; a query instance is classified by letting its k nearest neighbors
 * (optionally limited to a maximum distance) vote for their discrete labels.
 * @author Andrew Guillory gtg008g@mail.gatech.edu
 * @version 1.0
 */
public class KNNClassifier extends AbstractConditionalDistribution implements FunctionApproximater {
    /**
     * The distance measure used to build the tree and to weight votes
     */
    private DistanceMeasure distanceMeasure;
    /**
     * The maximum neighbor distance; a value of zero or less means no limit
     */
    private double range;
    /**
     * The number of neighbors
     */
    private int k;
    /**
     * Whether or not each neighbor's vote is divided by its distance to the query
     */
    private boolean weightByDistance;
    /**
     * The number of discrete label values, set by estimate
     */
    private int classRange;
    /**
     * The tree of training examples, built by estimate
     */
    private KDTree tree;

    /**
     * Make a new knn classifier with k = 1, a EuclideanDistance measure,
     * unweighted votes and no range limit
     */
    public KNNClassifier() {
        this(1, new EuclideanDistance());
    }

    /**
     * Build a new classifier with unweighted votes and no range limit
     * @param k the number of neighbors
     * @param measure the distance measure
     */
    public KNNClassifier(int k, DistanceMeasure measure) {
        this(k, false, measure, -1);
    }

    /**
     * Build a new classifier with no range limit
     * @param k the number of neighbors
     * @param weight true to divide each neighbor's vote by its distance to the query
     * @param measure the distance measure
     */
    public KNNClassifier(int k, boolean weight, DistanceMeasure measure) {
        this(k, weight, measure, -1);
    }

    /**
     * Build a new classifier
     * @param k the number of neighbors
     * @param weight true to divide each neighbor's vote by its distance to the query
     * @param measure the distance measure
     * @param range the maximum neighbor distance; zero or less for no limit
     */
    public KNNClassifier(int k, boolean weight, DistanceMeasure measure, double range) {
        this.k = k;
        this.weightByDistance = weight;
        this.range = range;
        this.distanceMeasure = measure;
    }

    /**
     * Estimate from data: records the number of label classes and builds the
     * KD-tree over the examples.  If the data set has no description yet, one
     * is created from the data and attached to it.
     * @param examples the training examples
     */
    public void estimate(DataSet examples) {
        if (examples.getDescription() == null) {
            examples.setDescription(new DataSetDescription(examples));
        }
        classRange = examples.getDescription().getLabelDescription().getDiscreteRange();
        tree = new KDTree(examples, distanceMeasure);
    }

    /**
     * Get the class distribution for an instance.  Each neighbor adds its
     * weight (divided by its distance when weighting by distance) to the
     * entry for its label, and the result is normalized.  If some neighbor
     * lies at distance zero (infinite vote), the probability mass is split
     * evenly among the classes that received infinite votes.  If no neighbor
     * contributes any mass (for example, none lies within the range limit),
     * a uniform distribution is returned.
     * @param data the data
     * @return the class distribution
     * @throws IllegalStateException if estimate has not been called yet
     */
    public Distribution distributionFor(Instance data) {
        if (tree == null) {
            throw new IllegalStateException("estimate must be called before distributionFor");
        }
        Object[] neighbors = findNeighbors(data);
        double[] distribution = new double[classRange];
        for (int i = 0; i < neighbors.length; i++) {
            Instance neighbor = (Instance) neighbors[i];
            distribution[neighbor.getLabel().getDiscrete()] += vote(data, neighbor);
        }
        normalize(distribution);
        return new DiscreteDistribution(distribution);
    }

    /**
     * Look up the neighbors of a query, honoring the range limit when one is set.
     * @param data the query instance
     * @return the neighbors returned by the tree
     */
    private Object[] findNeighbors(Instance data) {
        if (range > 0) {
            return tree.knnrange(data, k, range);
        }
        return tree.knn(data, k);
    }

    /**
     * Compute the vote a single neighbor casts for its label.
     * @param query the instance being classified
     * @param neighbor the neighbor casting the vote
     * @return the neighbor's weight, divided by its distance when weighting by distance
     */
    private double vote(Instance query, Instance neighbor) {
        double weight = neighbor.getWeight();
        if (!weightByDistance) {
            return weight;
        }
        if (weight == 0) {
            // A zero-weight neighbor contributes nothing; without this check a
            // zero-weight neighbor at distance zero would yield 0/0 = NaN and
            // poison every probability in the distribution.
            return 0;
        }
        return weight / distanceMeasure.value(query, neighbor);
    }

    /**
     * Normalize the accumulated votes in place so they sum to one.
     * @param distribution the per-class votes, overwritten with probabilities
     */
    private static void normalize(double[] distribution) {
        double sum = 0;
        for (int i = 0; i < distribution.length; i++) {
            sum += distribution[i];
        }
        if (Double.isInfinite(sum)) {
            // At least one neighbor sits at distance zero: share the mass
            // evenly among the classes with infinite votes only.
            sum = 0;
            for (int i = 0; i < distribution.length; i++) {
                if (Double.isInfinite(distribution[i])) {
                    distribution[i] = 1;
                    sum++;
                } else {
                    distribution[i] = 0;
                }
            }
        }
        if (sum == 0) {
            // No mass at all (no neighbors found, or only zero-weight ones):
            // dividing would give 0/0 = NaN, so fall back to a uniform distribution.
            for (int i = 0; i < distribution.length; i++) {
                distribution[i] = 1.0 / distribution.length;
            }
            return;
        }
        for (int i = 0; i < distribution.length; i++) {
            distribution[i] /= sum;
        }
    }

    /**
     * Get the classification for an example
     * @param data the data to get the classification for
     * @return the most probable class according to distributionFor
     */
    public Instance value(Instance data) {
        return distributionFor(data).mode();
    }

    /**
     * Get the distance measure
     * @return the distance measure
     */
    public DistanceMeasure getDistanceMeasure() {
        return distanceMeasure;
    }

    /**
     * Get the k value
     * @return the number of neighbors
     */
    public int getK() {
        return k;
    }

    /**
     * Get the range limit
     * @return the maximum neighbor distance; zero or less means no limit
     */
    public double getRange() {
        return range;
    }

    /**
     * Does it weight by distance
     * @return true if it does
     */
    public boolean isWeightByDistance() {
        return weightByDistance;
    }

    /**
     * Set the distance measure.  Note that the KD-tree is built with the
     * measure in effect when estimate is called.
     * @param measure the new measure
     */
    public void setDistanceMeasure(DistanceMeasure measure) {
        distanceMeasure = measure;
    }

    /**
     * Set the k value
     * @param i the new number of neighbors
     */
    public void setK(int i) {
        k = i;
    }

    /**
     * Set the range limit
     * @param r the new maximum neighbor distance; zero or less for no limit
     */
    public void setRange(double r) {
        range = r;
    }

    /**
     * Set the new weighting policy
     * @param b true to divide each neighbor's vote by its distance
     */
    public void setWeightByDistance(boolean b) {
        weightByDistance = b;
    }
}