/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* IBk.java
* Copyright (C) 1999 Stuart Inglis, Len Trigg, Eibe Frank
*
*/
package weka.classifiers.lazy;
import weka.classifiers.Classifier;
import weka.classifiers.DistributionClassifier;
import weka.classifiers.Evaluation;
import weka.classifiers.UpdateableClassifier;
import java.io.*;
import java.util.*;
import weka.core.KDTree;
import weka.core.DistanceFunction;
import weka.core.EuclideanDistance;
import weka.core.Utils;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.OptionHandler;
import weka.core.UnsupportedAttributeTypeException;
import weka.core.WeightedInstancesHandler;
/**
* <i>K</i>-nearest neighbour classifier. For more information, see <p>
*
* Aha, D., and D. Kibler (1991) "Instance-based learning algorithms",
* <i>Machine Learning</i>, vol.6, pp. 37-66.<p>
*
* Valid options are:<p>
*
* -K num <br>
* Set the number of nearest neighbours to use in prediction
* (default 1) <p>
*
* -W num <br>
* Set a fixed window size for incremental train/testing. As
* new training instances are added, oldest instances are removed
* to maintain the number of training instances at this size.
* (default no window) <p>
*
* -D <br>
* Neighbours will be weighted by the inverse of their distance
* when voting. (default equal weighting) <p>
*
* -F <br>
* Neighbours will be weighted by their similarity when voting.
* (default equal weighting) <p>
*
* -X <br>
* Select the number of neighbours to use by hold-one-out cross
* validation, with an upper limit given by the -K option. <p>
*
* -S <br>
* When k is selected by cross-validation for numeric class attributes,
* minimize mean-squared error. (default mean absolute error) <p>
*
* -N <br>
* Turn off normalization. <p>
*
* -E <kdtree class><br>
* Specify the KDTree class to use for nearest-neighbour search,
* followed by its options (it can only use the same distance
* function as XMeans).<p>
*
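* A typical command-line invocation, selecting k by hold-one-out
* cross-validation with an upper limit of 5 (the dataset path is a
* placeholder):<p>
*
* <pre> java weka.classifiers.lazy.IBk -t mydata.arff -K 5 -X </pre><p>
*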
* @author Gabi Schmidberger (gabi@cs.waikato.ac.nz)
* @author Stuart Inglis (singlis@cs.waikato.ac.nz)
* @author Len Trigg (trigg@cs.waikato.ac.nz)
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @version $Revision: 1.1.1.1 $
*/
public class IBk extends DistributionClassifier implements
OptionHandler, UpdateableClassifier, WeightedInstancesHandler {
/**
* A class for storing data about a neighbouring instance
*/
private class NeighbourNode {
/** The neighbour instance */
private Instance m_Instance;
/** The distance from the current instance to this neighbour */
private double m_Distance;
/** A link to the next neighbour instance */
private NeighbourNode m_Next;
/**
* Create a new neighbour node.
*
* @param distance the distance to the neighbour
* @param instance the neighbour instance
* @param next the next neighbour node
*/
public NeighbourNode(double distance, Instance instance,
NeighbourNode next) {
m_Distance = distance;
m_Instance = instance;
m_Next = next;
}
/**
* Create a new neighbour node that doesn't link to any other nodes.
* @param distance the distance to the neighbour
* @param instance the neighbour instance
*/
public NeighbourNode(double distance, Instance instance) {
this(distance, instance, null);
}
}
/**
* A class for a linked list to store the nearest k neighbours
* to an instance. We use a list so that we can handle cases
* where multiple neighbours are the same distance away: the
* list always holds at least k elements, and more when there
* are ties at the k'th distance.
*/
private class NeighbourList {
/** The first node in the list */
private NeighbourNode m_First;
/** The last node in the list */
private NeighbourNode m_Last;
/** The number of nodes to attempt to maintain in the list */
private int m_Length = 1;
/**
* Creates the neighbourlist with a desired length
*
* @param length the length of list to attempt to maintain
*/
public NeighbourList(int length) {
m_Length = length;
}
/**
* Gets whether the list is empty.
* @return true if so
*/
public boolean isEmpty() {
return (m_First == null);
}
/**
* Gets the current length of the list.
* @return the current length of the list
*/
public int currentLength() {
int i = 0;
NeighbourNode current = m_First;
while (current != null) {
i++;
current = current.m_Next;
}
return i;
}
/**
* Inserts an instance neighbour into the list, maintaining the list
* sorted by distance.
*
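* For example (hypothetical distances): with m_Length = 2, inserting
* distances 0.3, 0.1, 0.3, 0.3 leaves [0.1, 0.3, 0.3, 0.3] in the
* list, because the list is only truncated at a point where the
* distance strictly increases, so ties at the k'th distance are all
* retained.
*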
* @param distance the distance to the instance
* @param instance the neighbouring instance
*/
public void insertSorted(double distance, Instance instance) {
if (isEmpty()) {
m_First = m_Last = new NeighbourNode(distance, instance);
} else {
NeighbourNode current = m_First;
if (distance < m_First.m_Distance) {// Insert at head
m_First = new NeighbourNode(distance, instance, m_First);
} else { // Insert further down the list
for( ;(current.m_Next != null) &&
(current.m_Next.m_Distance < distance);
current = current.m_Next);
current.m_Next = new NeighbourNode(distance, instance,
current.m_Next);
if (current.equals(m_Last)) {
m_Last = current.m_Next;
}
}
// Trim down the list until we've got k list elements (or more if the
// distance to the last elements is the same).
int valcount = 0;
for(current = m_First; current.m_Next != null;
current = current.m_Next) {
valcount++;
if ((valcount >= m_Length) && (current.m_Distance !=
current.m_Next.m_Distance)) {
m_Last = current;
current.m_Next = null;
break;
}
}
}
}
/**
* Prunes the list to contain the k nearest neighbours. If there are
* multiple neighbours at the k'th distance, all will be kept.
*
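* For example (hypothetical distances), pruning [0.1, 0.3, 0.3, 0.7]
* to k = 2 keeps [0.1, 0.3, 0.3], since both neighbours at the second
* nearest distance are equally close.
*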
* @param k the number of neighbours to keep in the list.
*/
public void pruneToK(int k) {
if (isEmpty()) {
return;
}
if (k < 1) {
k = 1;
}
int currentK = 0;
double currentDist = m_First.m_Distance;
NeighbourNode current = m_First;
for(; current.m_Next != null; current = current.m_Next) {
currentK++;
currentDist = current.m_Distance;
if ((currentK >= k) && (currentDist != current.m_Next.m_Distance)) {
m_Last = current;
current.m_Next = null;
break;
}
}
}
/**
* Prints out the contents of the neighbourlist
*/
public void printList() {
if (isEmpty()) {
System.out.println("Empty list");
} else {
NeighbourNode current = m_First;
//System.out.print("Node:");
while (current != null) {
System.out.println("Node: instance " + current.m_Instance
+ ", distance " + current.m_Distance);
//System.out.print("distance " + current.m_Distance);
current = current.m_Next;
}
System.out.println();
}
}
}
/** KDTrees class if KDTrees are used */
private KDTree m_KDTree = null;
/** The training instances used for classification. */
protected Instances m_Train;
/** The number of class values (or 1 if predicting numeric) */
protected int m_NumClasses;
/** The class attribute type */
protected int m_ClassType;
/** The number of neighbours to use for classification (currently) */
protected int m_kNN;
/** Distance functions */
protected DistanceFunction m_DistanceF = null;
/**
* The value of kNN provided by the user. This may differ from
* m_kNN if cross-validation is being used
*/
protected int m_kNNUpper;
/**
* Whether the value of k selected by cross validation has
* been invalidated by a change in the training instances
*/
protected boolean m_kNNValid;
/**
* The maximum number of training instances allowed. When
* this limit is reached, old training instances are removed,
* so the training data is "windowed". Set to 0 for unlimited
* numbers of instances.
*/
protected int m_WindowSize;
/** Whether the neighbours should be distance-weighted */
protected int m_DistanceWeighting;
/** Whether to select k by cross validation */
protected boolean m_CrossValidate;
/**
* Whether to minimise mean squared error rather than mean absolute
* error when cross-validating on numeric prediction tasks
*/
protected boolean m_MeanSquared;
/** True if debugging output should be printed */
boolean m_Debug;
/** True if normalization is turned off */
protected boolean m_DontNormalize;
/** Define possible instance weighting methods */
public static final int WEIGHT_NONE = 1;
public static final int WEIGHT_INVERSE = 2;
public static final int WEIGHT_SIMILARITY = 4;
public static final Tag [] TAGS_WEIGHTING = {
new Tag(WEIGHT_NONE, "No distance weighting"),
new Tag(WEIGHT_INVERSE, "Weight by 1/distance"),
new Tag(WEIGHT_SIMILARITY, "Weight by 1-distance")
};
/** The number of attributes that contribute to a prediction */
protected double m_NumAttributesUsed;
/** Ranges of the universe of data, lowest value, highest value and width */
protected double [][] m_Ranges;
/** Index in ranges for LOW and HIGH and WIDTH */
protected static int R_MIN = 0;
protected static int R_MAX = 1;
protected static int R_WIDTH = 2;
/**
* IBk classifier. Simple instance-based learner that uses the class
* of the nearest k training instances for the class of the test
* instances.
*
* @param k the number of nearest neighbours to use for prediction
*/
public IBk(int k) {
init();
setKNN(k);
}
/**
* IB1 classifier. Instance-based learner. Predicts the class of the
* single nearest training instance for each test instance.
*/
public IBk() {
init();
}
/**
* Get the value of Debug.
* @return Value of Debug.
*/
public boolean getDebug() {
return m_Debug;
}
/**
* Set the value of Debug.
* @param newDebug Value to assign to Debug.
*/
public void setDebug(boolean newDebug) {
m_Debug = newDebug;
}
/**
* Sets the KDTree class.
* @param k a KDTree object with all options set
*/
public void setKDTree(KDTree k) {
m_KDTree = k;
}
/**
* Gets the KDTree class.
* @return flag if KDTrees are used
*/
public KDTree getKDTree() {
return m_KDTree;
}
/**
* Gets the KDTree specification string, which contains the class name of
* the KDTree class and any options to the KDTree
* @return the KDTree string.
*/
protected String getKDTreeSpec() {
KDTree c = getKDTree();
if (c instanceof OptionHandler) {
return c.getClass().getName() + " "
+ Utils.joinOptions(((OptionHandler)c).getOptions());
}
return c.getClass().getName();
}
/**
* Set the number of neighbours the learner is to use.
* @param k the number of neighbours.
*/
public void setKNN(int k) {
m_kNN = k;
m_kNNUpper = k;
m_kNNValid = false;
}
/**
* Gets the number of neighbours the learner will use.
*
* @return the number of neighbours
*/
public int getKNN() {
return m_kNN;
}
/**
* Gets the maximum number of instances allowed in the training
* pool. The addition of new instances above this value will result
* in old instances being removed. A value of 0 signifies no limit
* to the number of training instances.
*
* @return Value of WindowSize
*/
public int getWindowSize() {
return m_WindowSize;
}
/**
* Sets the maximum number of instances allowed in the training
* pool. The addition of new instances above this value will result
* in old instances being removed. A value of 0 signifies no limit
* to the number of training instances.
*
* @param newWindowSize Value to assign to WindowSize.
*/
public void setWindowSize(int newWindowSize) {
m_WindowSize = newWindowSize;
}
/**
* Gets the distance weighting method used. Will be one of
* WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY
*
* @return the distance weighting method used.
*/
public SelectedTag getDistanceWeighting() {
return new SelectedTag(m_DistanceWeighting, TAGS_WEIGHTING);
}
/**
* Sets the distance weighting method used. Values other than
* WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY will be ignored.
*
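* For example, to weight votes by inverse distance:<p>
*
* <pre>
* setDistanceWeighting(new SelectedTag(IBk.WEIGHT_INVERSE,
*                                      IBk.TAGS_WEIGHTING));
* </pre>
*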
* @param newMethod the distance weighting method to use
*/
public void setDistanceWeighting(SelectedTag newMethod) {
if (newMethod.getTags() == TAGS_WEIGHTING) {
m_DistanceWeighting = newMethod.getSelectedTag().getID();
}
}
/**
* Gets whether the mean squared error is used rather than mean
* absolute error when doing cross-validation.
*
* @return true if so.
*/
public boolean getMeanSquared() {
return m_MeanSquared;
}
/**
* Sets whether the mean squared error is used rather than mean
* absolute error when doing cross-validation.
*
* @param newMeanSquared true if so.
*/
public void setMeanSquared(boolean newMeanSquared) {
m_MeanSquared = newMeanSquared;
}
/**
* Gets whether hold-one-out cross-validation will be used
* to select the best k value
*
* @return true if cross-validation will be used.
*/
public boolean getCrossValidate() {
return m_CrossValidate;
}
/**
* Sets whether hold-one-out cross-validation will be used
* to select the best k value
*
* @param newCrossValidate true if cross-validation should be used.
*/
public void setCrossValidate(boolean newCrossValidate) {
m_CrossValidate = newCrossValidate;
}
/**
* Get the number of training instances the classifier is currently using.
*
* @return the number of training instances
*/
public int getNumTraining() {
return m_Train.numInstances();
}
/**
* Get an attribute's minimum observed value.
*
* @param index the attribute's index
* @return the minimum observed value
* @exception Exception if the ranges have not been computed yet
*/
public double getAttributeMin(int index) throws Exception {
if (m_Ranges == null) {
throw new Exception("Minimum value for attribute not available!");
}
return m_Ranges[index][R_MIN];
}
/**
* Get an attribute's maximum observed value.
*
* @param index the attribute's index
* @return the maximum observed value
* @exception Exception if the ranges have not been computed yet
*/
public double getAttributeMax(int index) throws Exception {
if (m_Ranges == null) {
throw new Exception("Maximum value for attribute not available!");
}
return m_Ranges[index][R_MAX];
}
/**
* Gets whether normalization is turned off.
* @return Value of DontNormalize.
*/
public boolean getNoNormalization() {
return m_DontNormalize;
}
/**
* Set whether normalization is turned off.
* @param v Value to assign to DontNormalize.
*/
public void setNoNormalization(boolean v) {
m_DontNormalize = v;
}
/**
* Generates the classifier.
* @param instances set of instances serving as training data
* @exception Exception if the classifier has not been generated successfully
*/
public void buildClassifier(Instances instances) throws Exception {
if (instances.classIndex() < 0) {
throw new Exception ("No class attribute assigned to instances");
}
if (instances.checkForStringAttributes()) {
throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
}
try {
m_NumClasses = instances.numClasses();
m_ClassType = instances.classAttribute().type();
} catch (Exception ex) {
throw new Error("This should never be reached");
}
// Throw away training instances with missing class
m_Train = new Instances(instances, 0, instances.numInstances());
m_Train.deleteWithMissingClass();
// Throw away initial instances until within the specified window size
if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) {
m_Train = new Instances(m_Train,
m_Train.numInstances()-m_WindowSize,
m_WindowSize);
}
// make ranges if needed for normalization and/or for the KDTree
if ((!m_DontNormalize) || (m_KDTree != null)) {
// Initializes and calculates the ranges for the training instances
m_Ranges = m_Train.initializeRanges();
// Instances.printRanges(m_Ranges);
}
// if already some instances here, then build KDTree
if ((m_KDTree != null) && (m_Train.numInstances() > 0)) {
m_KDTree.buildKDTree(m_Train);
OOPS("KDTree built in buildClassifier");
OOPS(" " + m_KDTree.toString());
}
// Compute the number of attributes that contribute
// to each prediction
m_NumAttributesUsed = 0.0;
for (int i = 0; i < m_Train.numAttributes(); i++) {
if ((i != m_Train.classIndex()) &&
(m_Train.attribute(i).isNominal() ||
m_Train.attribute(i).isNumeric())) {
m_NumAttributesUsed += 1.0;
}
}
// Invalidate any k previously selected by cross-validation
m_kNNValid = false;
}
/**
* Adds the supplied instance to the training set
*
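* A minimal incremental-learning sketch (data loading is elided and
* the variable names are illustrative; all instances are assumed to
* share the training header):<p>
*
* <pre>
* IBk knn = new IBk(3);
* knn.buildClassifier(initialTrain); // seed the model
* knn.updateClassifier(newInstance); // add instances as they arrive
* </pre>
*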
* @param instance the instance to add
* @exception Exception if instance could not be incorporated
* successfully
*/
public void updateClassifier(Instance instance) throws Exception {
if (!m_Train.equalHeaders(instance.dataset())) {
throw new Exception("Incompatible instance types");
}
if (instance.classIsMissing()) {
return;
}
// update ranges
// but only if normalize flag is on or KDTree is chosen
if ((!m_DontNormalize) || (m_KDTree != null)) {
m_Ranges = Instances.updateRanges(instance, m_Ranges);
}
// add instance to training set
m_Train.add(instance);
// update KDTree
if (m_KDTree != null) {
if (m_KDTree.isValid() && (m_KDTree.numInstances() > 0))
m_KDTree.updateKDTree(instance);
}
m_kNNValid = false;
if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
while (m_Train.numInstances() > m_WindowSize) {
m_Train.delete(0);
if (m_KDTree != null)
m_KDTree.setValid(false);
}
}
}
/**
* Calculates the class membership probabilities for the given test instance.
*
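* For example (a sketch; knn is a built classifier and test shares
* the training header):<p>
*
* <pre>
* double [] dist = knn.distributionForInstance(test.instance(0));
* int predicted = Utils.maxIndex(dist); // most probable class index
* </pre>
*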
* @param instance the instance to be classified
* @return predicted class probability distribution
* @exception Exception if an error occurred during the prediction
*/
public double [] distributionForInstance(Instance instance)
throws Exception {
if (m_Train.numInstances() == 0) {
throw new Exception("No training instances!");
}
// cut instances to windowsize
if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
m_kNNValid = false;
while (m_Train.numInstances() > m_WindowSize) {
m_Train.delete(0);
if (m_KDTree != null)
m_KDTree.setValid(false);
}
}
if ((m_KDTree != null) && (!m_KDTree.isValid())) {
m_KDTree.buildKDTree(m_Train);
//OOPS("KDTree build in distributionForInstance");
//OOPS(" " + m_KDTree.toString());
}
// Select k by cross validation
if (!m_kNNValid && (m_CrossValidate) && (m_kNN > 1)) {
crossValidate();
}
// update ranges - for norm()-method
if (!m_DontNormalize) {
m_Ranges = Instances.updateRanges(instance, m_Ranges);
}
// update ranges for the norm() method in the KDTree's distance class
if (m_KDTree != null) {
m_KDTree.addLooslyInstance(instance);
}
// find neighbours and make distribution
NeighbourList neighbourlist = findNeighbours(instance);
return makeDistribution(neighbourlist);
}
/**
* Returns an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
public Enumeration listOptions() {
Vector newVector = new Vector(9);
newVector.addElement(new Option(
"\tWeight neighbours by the inverse of their distance\n"
+"\t(use when k > 1)",
"D", 0, "-D"));
newVector.addElement(new Option(
"\tWeight neighbours by 1 - their distance\n"
+"\t(use when k > 1)",
"F", 0, "-F"));
newVector.addElement(new Option(
"\tNumber of nearest neighbours (k) used in classification.\n"
+"\t(Default = 1)",
"K", 1,"-K <number of neighbours>"));
newVector.addElement(new Option(
"\tMinimise mean squared error rather than mean absolute\n"
+"\terror when using -X option with numeric prediction.",
"S", 0,"-S"));
newVector.addElement(new Option(
"\tMaximum number of training instances maintained.\n"
+"\tTraining instances are dropped FIFO. (Default = no window)",
"W", 1,"-W <window size>"));
newVector.addElement(new Option(
"\tSelect the number of nearest neighbours between 1\n"
+"\tand the k value specified using hold-one-out evaluation\n"
+"\ton the training data (use when k > 1)",
"X", 0,"-X"));
newVector.addElement(new Option(
"\tDon't normalize the data.\n",
"N", 0, "-N"));
newVector.addElement(new Option(
"\tFull class name of KDTree class to use, followed\n" +
"\tby scheme options.\n" +
"\teg: \"weka.core.KDTree -P\"\n" +
"(default = no KDTree class used).",
"E", 1, "-E <KDTree class specification>"));
return newVector.elements();
}
/**
* Parses a given list of options. Valid options are:<p>
*
* -K num <br>
* Set the number of nearest neighbours to use in prediction
* (default 1) <p>
*
* -W num <br>
* Set a fixed window size for incremental train/testing. As
* new training instances are added, oldest instances are removed
* to maintain the number of training instances at this size.
* (default no window) <p>
*
* -D <br>
* Neighbours will be weighted by the inverse of their distance
* when voting. (default equal weighting) <p>
*
* -F <br>
* Neighbours will be weighted by their similarity when voting.
* (default equal weighting) <p>
*
* -X <br>
* Select the number of neighbours to use by hold-one-out cross
* validation, with an upper limit given by the -K option. <p>
*
* -S <br>
* When k is selected by cross-validation for numeric class attributes,
* minimize mean-squared error. (default mean absolute error) <p>
*
* -N <br>
* Turn off normalization. <p>
*
* -E <kdtree class> <br>
* Specify the KDTree class to use for nearest-neighbour search,
* followed by its options. (default no KDTree) <p>
*
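* For example, programmatically (the option string is illustrative):<p>
*
* <pre> classifier.setOptions(Utils.splitOptions("-K 10 -X -D")); </pre>
*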
* @param options the list of options as an array of strings
* @exception Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception {
String knnString = Utils.getOption('K', options);
if (knnString.length() != 0) {
setKNN(Integer.parseInt(knnString));
} else {
setKNN(1);
}
String windowString = Utils.getOption('W', options);
if (windowString.length() != 0) {
setWindowSize(Integer.parseInt(windowString));
} else {
setWindowSize(0);
}
if (Utils.getFlag('D', options)) {
setDistanceWeighting(new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));
} else if (Utils.getFlag('F', options)) {
setDistanceWeighting(new SelectedTag(WEIGHT_SIMILARITY, TAGS_WEIGHTING));
} else {
setDistanceWeighting(new SelectedTag(WEIGHT_NONE, TAGS_WEIGHTING));
}
setCrossValidate(Utils.getFlag('X', options));
setMeanSquared(Utils.getFlag('S', options));
setNoNormalization(Utils.getFlag('N', options));
String funcString = Utils.getOption('E', options);
if (funcString.length() != 0) {
String [] funcSpec = Utils.splitOptions(funcString);
if (funcSpec.length == 0) {
throw new Exception("Invalid function specification string");
}
String funcName = funcSpec[0];
funcSpec[0] = "";
setKDTree((KDTree) Utils.forName(KDTree.class, funcName, funcSpec));
}
Utils.checkForRemainingOptions(options);
}
/**
* Gets the current settings of IBk.
*
* @return an array of strings suitable for passing to setOptions()
*/
public String [] getOptions() {
String [] options = new String [11];
int current = 0;
options[current++] = "-K"; options[current++] = "" + getKNN();
options[current++] = "-W"; options[current++] = "" + m_WindowSize;
if (getCrossValidate()) {
options[current++] = "-X";
}
if (getMeanSquared()) {
options[current++] = "-S";
}
if (m_DistanceWeighting == WEIGHT_INVERSE) {
options[current++] = "-D";
} else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {
options[current++] = "-F";
}
if (m_DontNormalize) {
options[current++] = "-N";
}
if (getKDTree() != null) {
options[current++] = "-E";
options[current++] = "" + getKDTreeSpec();
}
while (current < options.length) {
options[current++] = "";
}
return options;
}
/**
* Returns a description of this classifier.
*
* @return a description of this classifier as a string.
*/
public String toString() {
if (m_Train == null) {
return "IBk: No model built yet.";
}
if (!m_kNNValid && m_CrossValidate) {
crossValidate();
}
String result = "IBk instance-based classifier\n" +
"using " + m_kNN;
switch (m_DistanceWeighting) {
case WEIGHT_INVERSE:
result += " inverse-distance-weighted";
break;
case WEIGHT_SIMILARITY:
result += " similarity-weighted";
break;
}
result += " nearest neighbour(s) for classification\n";
if (m_WindowSize != 0) {
result += "using a maximum of "
+ m_WindowSize + " (windowed) training instances\n";
}
return result;
}
/**
* Initialise scheme variables.
*/
private void init() {
setKNN(1);
m_WindowSize = 0;
m_DistanceWeighting = WEIGHT_NONE;
m_CrossValidate = false;
m_MeanSquared = false;
m_DontNormalize = false;
}
/**
* Calculates the distance between two instances
*
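* The result is a normalized Euclidean distance: the square root of
* the summed squared per-attribute differences divided by
* m_NumAttributesUsed. For example, two instances whose normalized
* values differ by 0.3 and 0.4 on two attributes are
* sqrt((0.09 + 0.16) / 2), roughly 0.354, apart.
*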
* @param first the first instance
* @param second the second instance
* @return the distance between the two given instances, between 0 and 1
*/
private double distance(Instance first, Instance second) {
if (!Instances.inRanges(first,m_Ranges))
OOPS("Not in ranges");
if (!Instances.inRanges(second,m_Ranges))
OOPS("Not in ranges");
double distance = 0;
int firstI, secondI;
for (int p1 = 0, p2 = 0;
p1 < first.numValues() || p2 < second.numValues();) {
if (p1 >= first.numValues()) {
firstI = m_Train.numAttributes();
} else {
firstI = first.index(p1);
}
if (p2 >= second.numValues()) {
secondI = m_Train.numAttributes();
} else {
secondI = second.index(p2);
}
if (firstI == m_Train.classIndex()) {
p1++; continue;
}
if (secondI == m_Train.classIndex()) {
p2++; continue;
}
double diff;
if (firstI == secondI) {
diff = difference(firstI,
first.valueSparse(p1),
second.valueSparse(p2));
p1++; p2++;
} else if (firstI > secondI) {
diff = difference(secondI,
0, second.valueSparse(p2));
p2++;
} else {
diff = difference(firstI,
first.valueSparse(p1), 0);
p1++;
}
distance += diff * diff;
}
distance = Math.sqrt(distance / m_NumAttributesUsed);
return distance;
}
/**
* Computes the difference between two given attribute values.
*
* @param index the attribute's index
* @param val1 the first value
* @param val2 the second value
* @return the (normalized) difference between the two values
*/
private double difference(int index, double val1, double val2) {
switch (m_Train.attribute(index).type()) {
case Attribute.NOMINAL:
// If attribute is nominal
if (Instance.isMissingValue(val1) ||
Instance.isMissingValue(val2) ||
((int)val1 != (int)val2)) {
return 1;
} else {
return 0;
}
case Attribute.NUMERIC:
// If attribute is numeric
if (Instance.isMissingValue(val1) ||
Instance.isMissingValue(val2)) {
if (Instance.isMissingValue(val1) &&
Instance.isMissingValue(val2)) {
return 1;
} else {
double diff;
if (Instance.isMissingValue(val2)) {
diff = norm(val1, index);
} else {
diff = norm(val2, index);
}
if (diff < 0.5) {
diff = 1.0 - diff;
}
return diff;
}
} else {
return norm(val1, index) - norm(val2, index);
}
default:
return 0;
}
}
/**
* Normalizes a given value of a numeric attribute.
*
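* The value is scaled to [0, 1] as (x - min) / (max - min) over the
* attribute's observed range; for example, with an observed range of
* [2, 10], norm(6, i) returns (6 - 2) / (10 - 2) = 0.5.
*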
* @param x the value to be normalized
* @param i the attribute's index
* @return the normalized value
*/
private double norm(double x, int i) {
if (m_DontNormalize) {
return x;
} else if (Double.isNaN(m_Ranges[i][R_MIN]) ||
Utils.eq(m_Ranges[i][R_MAX],m_Ranges[i][R_MIN])) {
return 0;
} else {
return (x - m_Ranges[i][R_MIN]) / (m_Ranges[i][R_MAX] - m_Ranges[i][R_MIN]);
}
}
/**
* Updates the minimum and maximum values for all the attributes
* based on a new instance.
*
* @param instance the new instance
*/
private void updateMinMax(Instance instance) {
for (int j = 0;j < m_Train.numAttributes(); j++) {
if (!instance.isMissing(j)) {
if (Double.isNaN(m_Ranges[j][R_MIN])) {
m_Ranges[j][R_MIN] = instance.value(j);
m_Ranges[j][R_MAX] = instance.value(j);
} else {
if (instance.value(j) < m_Ranges[j][R_MIN]) {
m_Ranges[j][R_MIN] = instance.value(j);
} else {
if (instance.value(j) > m_Ranges[j][R_MAX]) {
m_Ranges[j][R_MAX] = instance.value(j);
}
}
}
}
}
}
/**
* Build the list of nearest k neighbours to the given test instance.
*
* @param instance the instance to search for neighbours of
* @return a list of neighbours
* @exception Exception if an error occurs during the neighbour search
*/
private NeighbourList findNeighbours(Instance instance) throws Exception {
double distance;
NeighbourList neighbourlist = new NeighbourList(m_kNN);
// no KDTree: do a linear scan over the training instances
if (m_KDTree == null) {
Enumeration instEnum = m_Train.enumerateInstances();
int i = 0;
while (instEnum.hasMoreElements()) {
Instance trainInstance = (Instance) instEnum.nextElement();
if (instance != trainInstance) { // for hold-one-out cross-validation
distance = distance(instance, trainInstance);
if (neighbourlist.isEmpty() || (i < m_kNN) ||
(distance <= neighbourlist.m_Last.m_Distance)) {
neighbourlist.insertSorted(distance, trainInstance);
}
i++;
}
}
}
else {
// work with KDTree
double[] distanceList = new double[m_KDTree.numInstances()];
int[] instanceList = new int[m_KDTree.numInstances()];
int numOfNearest = m_KDTree.findKNearestNeighbour(instance, m_kNN,
instanceList, distanceList);
for (int i = 0; i < numOfNearest; i++) {
neighbourlist.insertSorted(distanceList[i],
m_KDTree.getInstances().instance(instanceList[i]));
}
}
//debug
//OOPS("Target: "+instance+" found "+neighbourlist.currentLength() + " neighbours\n");
//neighbourlist.printList();
return neighbourlist;
}
/**
* Turn the list of nearest neighbours into a probability distribution
*
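* For example (equal weighting, hypothetical counts): with 100
* training instances, 3 classes, and neighbours voting 2-1-0, the raw
* entries are 2.01, 1.01 and 0.01 (each class starts from the 1/100
* estimator correction) and are then normalized to sum to one.
*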
* @param neighbourlist the list of nearest neighbouring instances
* @return the probability distribution
*/
private double [] makeDistribution(NeighbourList neighbourlist)
throws Exception {
double total = 0, weight;
double [] distribution = new double [m_NumClasses];
// Set up a correction to the estimator
if (m_ClassType == Attribute.NOMINAL) {
for(int i = 0; i < m_NumClasses; i++) {
distribution[i] = 1.0 / Math.max(1,m_Train.numInstances());
}
total = (double)m_NumClasses / Math.max(1,m_Train.numInstances());
}
if (!neighbourlist.isEmpty()) {
// Collect class counts
NeighbourNode current = neighbourlist.m_First;
while (current != null) {
switch (m_DistanceWeighting) {
case WEIGHT_INVERSE:
weight = 1.0 / (current.m_Distance + 0.001); // to avoid div by zero
break;
case WEIGHT_SIMILARITY:
weight = 1.0 - current.m_Distance;
break;
default: // WEIGHT_NONE:
weight = 1.0;
break;
}
weight *= current.m_Instance.weight();
try {
switch (m_ClassType) {
case Attribute.NOMINAL:
distribution[(int)current.m_Instance.classValue()] += weight;
break;
case Attribute.NUMERIC:
distribution[0] += current.m_Instance.classValue() * weight;
break;
}
} catch (Exception ex) {
throw new Error("Data has no class attribute!");
}
total += weight;
current = current.m_Next;
}
}
// Normalise distribution
if (total > 0) {
Utils.normalize(distribution, total);
}
return distribution;
}
/**
* Select the best value for k by hold-one-out cross-validation.
* If the class attribute is nominal, classification error is
* minimised. If the class attribute is numeric, mean absolute
* error is minimised.
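*
* For each training instance the neighbour list is found once for the
* largest k and then pruned step by step (via pruneToK) to score every
* smaller k, so the search costs little more than a single pass.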
*/
private void crossValidate() {
try {
double [] performanceStats = new double [m_kNNUpper];
double [] performanceStatsSq = new double [m_kNNUpper];
for(int i = 0; i < m_kNNUpper; i++) {
performanceStats[i] = 0;
performanceStatsSq[i] = 0;
}
m_kNN = m_kNNUpper;
Instance instance;
NeighbourList neighbourlist;
for(int i = 0; i < m_Train.numInstances(); i++) {
if (m_Debug && (i % 50 == 0)) {
System.err.print("Cross validating "
+ i + "/" + m_Train.numInstances() + "\r");
}
instance = m_Train.instance(i);
neighbourlist = findNeighbours(instance);
for(int j = m_kNNUpper - 1; j >= 0; j--) {
// Update the performance stats
double [] distribution = makeDistribution(neighbourlist);
if (m_Train.classAttribute().isNumeric()) {
// For numeric classes the single distribution entry holds the
// (weighted) mean prediction, not a class index
double thisPrediction = distribution[0];
double err = thisPrediction - instance.classValue();
performanceStatsSq[j] += err * err; // Squared error
performanceStats[j] += Math.abs(err); // Absolute error
} else {
double thisPrediction = Utils.maxIndex(distribution);
if (thisPrediction != instance.classValue()) {
performanceStats[j] ++; // Classification error
}
}
if (j >= 1) {
neighbourlist.pruneToK(j);
}
}
}
// Display the results of the cross-validation
for(int i = 0; i < m_kNNUpper; i++) {
if (m_Debug) {
System.err.print("Hold-one-out performance of " + (i + 1)
+ " neighbours " );
}
if (m_Train.classAttribute().isNumeric()) {
if (m_Debug) {
if (m_MeanSquared) {
System.err.println("(RMSE) = "
+ Math.sqrt(performanceStatsSq[i]
/ m_Train.numInstances()));
} else {
System.err.println("(MAE) = "
+ performanceStats[i]
/ m_Train.numInstances());
}
}
} else {
if (m_Debug) {
System.err.println("(%ERR) = "
+ 100.0 * performanceStats[i]
/ m_Train.numInstances());
}
}
}
// Check through the performance stats and select the best
// k value (or the lowest k if more than one best)
double [] searchStats = performanceStats;
if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
searchStats = performanceStatsSq;
}
double bestPerformance = Double.NaN;
int bestK = 1;
for(int i = 0; i < m_kNNUpper; i++) {
if (Double.isNaN(bestPerformance)
|| (bestPerformance > searchStats[i])) {
bestPerformance = searchStats[i];
bestK = i + 1;
}
}
m_kNN = bestK;
if (m_Debug) {
System.err.println("Selected k = " + bestK);
}
m_kNNValid = true;
} catch (Exception ex) {
throw new Error("Couldn't optimize by cross-validation: "
+ex.getMessage());
}
}
/**
* Prints a debug message to standard out.
* @param output string that is printed
*/
private void OOPS(String output) {
System.out.println(output);
}
/**
* Main method for testing this class.
*
* @param argv should contain command line options (see setOptions)
*/
public static void main(String [] argv) {
try {
System.out.println(Evaluation.evaluateModel(new IBk(), argv));
} catch (Exception e) {
e.printStackTrace();
System.err.println(e.getMessage());
}
}
}