/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* LWR.java
* Copyright (C) 1999, 2002 Len Trigg, Eibe Frank
*
*/
package weka.classifiers.lazy;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.UpdateableClassifier;
import java.io.*;
import java.util.*;
import weka.core.*;
/**
* Locally-weighted regression. Uses an instance-based algorithm to assign
* instance weights which are then used by a linear regression model. For
* more information, see<p>
*
* Atkeson, C., A. Moore, and S. Schaal (1996) <i>Locally weighted
* learning</i>
* <a href="ftp://ftp.cc.gatech.edu/pub/people/cga/air1.ps.gz">download
* postscript</a>. <p>
*
* Valid options are:<p>
*
* -D <br>
* Produce debugging output. <p>
*
* -K num <br>
* Set the number of neighbours used for setting kernel bandwidth.
* (default all) <p>
*
* -W num <br>
* Set the weighting kernel shape to use. 1 = Inverse, 2 = Gaussian.
* (default 0 = Linear) <p>
*
* -S num <br>
* Set the attribute selection method to use. 1 = None, 2 = Greedy
* (default 0 = M5' method) <p>
*
* -C <br>
* Do not try to eliminate colinear attributes <p>
*
* -R num <br>
* The ridge parameter (default 1.0e-8) <p>
*
* @author Len Trigg (trigg@cs.waikato.ac.nz)
* @version $Revision: 1.1.1.1 $
*/
public class LWR extends Classifier
implements OptionHandler, UpdateableClassifier,
WeightedInstancesHandler {
/** The training instances used for classification. */
protected Instances m_Train;
/** The minimum observed values for numeric attributes (NaN until a non-missing value is seen). */
protected double [] m_Min;
/** The maximum observed values for numeric attributes (NaN until a non-missing value is seen). */
protected double [] m_Max;
/** True if debugging output should be printed */
protected boolean m_Debug;
/** The number of neighbours used to select the kernel bandwidth (&lt;= 0 means use all). */
protected int m_kNN = -1;
/** The weighting kernel method currently selected (one of LINEAR, INVERSE, GAUSS). */
protected int m_WeightKernel = LINEAR;
/** True if m_kNN should be set to all instances */
protected boolean m_UseAllK = true;
/** Kernel weighting method: linear kernel (the default). */
protected static final int LINEAR = 0;
/** Kernel weighting method: inverse-distance kernel. */
protected static final int INVERSE = 1;
/** Kernel weighting method: gaussian kernel. */
protected static final int GAUSS = 2;
/** The linear regression object, refitted to locally weighted data for each prediction. */
protected LinearRegression lr = new LinearRegression();
/**
 * Returns an enumeration describing the available options.
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions() {

  Vector options = new Vector(5);

  options.addElement(new Option(
      "\tProduce debugging output.\n\t(default no debugging output)",
      "D", 0, "-D"));
  options.addElement(new Option(
      "\tSet the number of neighbors used to set the kernel bandwidth.\n\t(default all)",
      "K", 1, "-K <number of neighbours>"));
  options.addElement(new Option(
      "\tSet the weighting kernel shape to use. 1 = Inverse, 2 = Gaussian.\n\t(default 0 = Linear)",
      "W", 1,"-W <number of weighting method>"));
  options.addElement(new Option(
      "\tSet the attribute selection method to use. 1 = None, 2 = Greedy.\n\t(default 0 = M5' method)",
      "S", 1, "-S <number of selection method>"));
  options.addElement(new Option(
      "\tDo not try to eliminate colinear attributes.\n",
      "C", 0, "-C"));
  options.addElement(new Option(
      "\tSet ridge parameter (default 1.0e-8).\n",
      "R", 1, "-R <double>"));

  return options.elements();
}
/**
 * Parses a given list of options. Valid options are:<p>
 *
 * -D <br>
 * Produce debugging output. <p>
 *
 * -K num <br>
 * Set the number of neighbours used for setting kernel bandwidth.
 * (default all) <p>
 *
 * -W num <br>
 * Set the weighting kernel shape to use. 1 = Inverse, 2 = Gaussian.
 * (default 0 = Linear) <p>
 *
 * -S num <br>
 * Set the attribute selection method to use. 1 = None, 2 = Greedy
 * (default 0 = M5' method) <p>
 *
 * -C <br>
 * Do not try to eliminate colinear attributes <p>
 *
 * -R num <br>
 * The ridge parameter (default 1.0e-8) <p>
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {

  // -K: kernel bandwidth neighbourhood size; 0 (or absent) means use all.
  String knnString = Utils.getOption('K', options);
  if (knnString.length() != 0) {
    setKNN(Integer.parseInt(knnString));
  } else {
    setKNN(0);
  }

  // -W: weighting kernel shape (LINEAR, INVERSE or GAUSS).
  String weightString = Utils.getOption('W', options);
  if (weightString.length() != 0) {
    setWeightingKernel(Integer.parseInt(weightString));
  } else {
    setWeightingKernel(LINEAR);
  }

  // -S: attribute selection method, forwarded to the underlying regression.
  String selectionString = Utils.getOption('S', options);
  if (selectionString.length() != 0) {
    setAttributeSelectionMethod(new SelectedTag(Integer
                                                .parseInt(selectionString),
                                                LinearRegression.TAGS_SELECTION));
  } else {
    setAttributeSelectionMethod(new SelectedTag(LinearRegression.SELECTION_M5,
                                                LinearRegression.TAGS_SELECTION));
  }

  // -C: presence of the flag DISABLES colinear-attribute elimination.
  setEliminateColinearAttributes(!Utils.getFlag('C', options));

  // -R: ridge parameter for the underlying regression.
  String ridgeString = Utils.getOption('R', options);
  if (ridgeString.length() != 0) {
    // Double.parseDouble replaces the wasteful new Double(...).doubleValue()
    // wrapper allocation; same parsing semantics and NumberFormatException.
    lr.setRidge(Double.parseDouble(ridgeString));
  } else {
    lr.setRidge(1.0e-8);
  }

  setDebug(Utils.getFlag('D', options));
}
/**
 * Gets the current settings of the classifier.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String [] getOptions() {

  // Worst case every option is emitted at once:
  // -D (1) + -W n (2) + -K n (2) + -S n (2) + -C (1) + -R n (2) = 10 slots.
  // The original array of 9 threw ArrayIndexOutOfBoundsException whenever
  // debugging, an explicit k-NN and -C were all active together.
  String [] options = new String [10];
  int current = 0;

  if (getDebug()) {
    options[current++] = "-D";
  }
  options[current++] = "-W"; options[current++] = "" + getWeightingKernel();
  if (!m_UseAllK) {
    options[current++] = "-K"; options[current++] = "" + getKNN();
  }
  options[current++] = "-S";
  options[current++] = "" + getAttributeSelectionMethod()
    .getSelectedTag().getID();
  if (!getEliminateColinearAttributes()) {
    options[current++] = "-C";
  }
  options[current++] = "-R";
  options[current++] = "" + lr.getRidge();

  // Pad unused slots so the returned array never contains nulls.
  while (current < options.length) {
    options[current++] = "";
  }
  return options;
}
/**
 * Get the value of EliminateColinearAttributes.
 * Delegates to the underlying linear regression object.
 *
 * @return Value of EliminateColinearAttributes.
 */
public boolean getEliminateColinearAttributes() {
return lr.getEliminateColinearAttributes();
}
/**
 * Set the value of EliminateColinearAttributes.
 * Delegates to the underlying linear regression object.
 *
 * @param newEliminateColinearAttributes Value to assign to EliminateColinearAttributes.
 */
public void setEliminateColinearAttributes(boolean newEliminateColinearAttributes) {
lr.setEliminateColinearAttributes(newEliminateColinearAttributes);
}
/**
 * Sets the method used to select attributes for use in the
 * linear regression. Delegates to the underlying linear regression object.
 *
 * @param method the attribute selection method to use.
 */
public void setAttributeSelectionMethod(SelectedTag method) {
lr.setAttributeSelectionMethod(method);
}
/**
 * Gets the method used to select attributes for use in the
 * linear regression. Delegates to the underlying linear regression object.
 *
 * @return the method to use.
 */
public SelectedTag getAttributeSelectionMethod() {
return lr.getAttributeSelectionMethod();
}
/**
 * Sets whether debugging output should be produced
 *
 * @param debug true if debugging output should be printed
 */
public void setDebug(boolean debug) {
m_Debug = debug;
}
/**
 * Gets whether debugging output should be produced
 *
 * @return true if debugging output should be printed
 */
public boolean getDebug() {
return m_Debug;
}
/**
 * Sets the number of neighbours used for kernel bandwidth setting.
 * The bandwidth is taken as the distance to the kth neighbour.
 *
 * @param knn the number of neighbours included inside the kernel
 * bandwidth, or 0 to specify using all neighbors.
 */
public void setKNN(int knn) {

  // Any non-positive value is normalised to 0, meaning "use all neighbours".
  if (knn <= 0) {
    m_kNN = 0;
    m_UseAllK = true;
    return;
  }
  m_kNN = knn;
  m_UseAllK = false;
}
/**
 * Gets the number of neighbours used for kernel bandwidth setting.
 * The bandwidth is taken as the distance to the kth neighbour.
 *
 * @return the number of neighbours included inside the kernel
 * bandwidth, or 0 for all neighbours
 */
public int getKNN() {
return m_kNN;
}
/**
 * Sets the kernel weighting method to use. Must be one of LINEAR,
 * INVERSE, or GAUSS, other values are ignored.
 *
 * @param kernel the new kernel method to use. Must be one of LINEAR,
 * INVERSE, or GAUSS
 */
public void setWeightingKernel(int kernel) {

  switch (kernel) {
  case LINEAR:
  case INVERSE:
  case GAUSS:
    m_WeightKernel = kernel;
    break;
  default:
    // Unrecognised kernel codes are silently ignored.
    break;
  }
}
/**
 * Gets the kernel weighting method to use.
 *
 * @return the current kernel method in use. Will be one of LINEAR,
 * INVERSE, or GAUSS
 */
public int getWeightingKernel() {
return m_WeightKernel;
}
/**
 * Gets an attribute's minimum observed value
 *
 * @param index the index of the attribute
 * @return the minimum observed value (NaN if no non-missing value has been seen)
 */
protected double getAttributeMin(int index) {
return m_Min[index];
}
/**
 * Gets an attribute's maximum observed value
 *
 * @param index the index of the attribute
 * @return the maximum observed value (NaN if no non-missing value has been seen)
 */
protected double getAttributeMax(int index) {
return m_Max[index];
}
/**
 * Generates the classifier.
 *
 * @param instances set of instances serving as training data
 * @exception Exception if the classifier has not been generated successfully
 */
public void buildClassifier(Instances instances) throws Exception {

  // LWR requires a numeric class and cannot handle string attributes.
  if (instances.classIndex() < 0) {
    throw new Exception("No class attribute assigned to instances");
  }
  if (instances.classAttribute().type() != Attribute.NUMERIC) {
    throw new UnsupportedClassTypeException("Class attribute must be numeric");
  }
  if (instances.checkForStringAttributes()) {
    throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
  }

  // Copy the data, discarding instances with a missing class value.
  m_Train = new Instances(instances, 0, instances.numInstances());
  m_Train.deleteWithMissingClass();

  // Record per-attribute observed minima/maxima for distance normalization.
  // NaN marks "no non-missing value seen yet".
  int numAtts = m_Train.numAttributes();
  m_Min = new double [numAtts];
  m_Max = new double [numAtts];
  Arrays.fill(m_Min, Double.NaN);
  Arrays.fill(m_Max, Double.NaN);
  for (int i = 0; i < m_Train.numInstances(); i++) {
    updateMinMax(m_Train.instance(i));
  }
}
/**
 * Adds the supplied instance to the training set
 *
 * @param instance the instance to add
 * @exception Exception if instance could not be incorporated
 * successfully
 */
public void updateClassifier(Instance instance) throws Exception {

  if (!m_Train.equalHeaders(instance.dataset())) {
    throw new Exception("Incompatible instance types");
  }
  if (instance.classIsMissing()) {
    // Instances with a missing class carry no training information here.
    return;
  }
  updateMinMax(instance);
  m_Train.add(instance);
}
/**
 * Predicts the class value for the given test instance by fitting a
 * fresh linear regression to the training data, weighted by each
 * training instance's kernel-transformed distance to the test instance.
 *
 * @param instance the instance to be classified
 * @return the predicted class value
 * @exception Exception if an error occurred during the prediction
 */
public double classifyInstance(Instance instance) throws Exception {
if (m_Train.numInstances() == 0) {
throw new Exception("No training instances!");
}
// Fold the test instance's values into the observed min/max ranges so
// distance normalization covers it too.
updateMinMax(instance);
// Get the distances to each training instance
double [] distance = new double [m_Train.numInstances()];
for (int i = 0; i < m_Train.numInstances(); i++) {
distance[i] = distance(instance, m_Train.instance(i));
}
// sortKey[i] is the index of the i-th nearest training instance.
int [] sortKey = Utils.sort(distance);
if (m_Debug) {
System.out.println("Instance Distances");
for (int i = 0; i < distance.length; i++) {
System.out.println("" + distance[sortKey[i]]);
}
}
// Determine the bandwidth: the distance to the kth nearest neighbour
// (or to the furthest instance when all neighbours are used).
int k = sortKey.length - 1;
if (!m_UseAllK && (m_kNN < k)) {
k = m_kNN;
}
double bandwidth = distance[sortKey[k]];
// If the kth distance ties the minimum distance, widen the bandwidth to
// the next strictly larger distance so the kernel can discriminate.
if (bandwidth == distance[sortKey[0]]) {
for (int i = k; i < sortKey.length; i++) {
if (distance[sortKey[i]] > bandwidth) {
bandwidth = distance[sortKey[i]];
break;
}
}
if (bandwidth == distance[sortKey[0]]) {
// All distances are equal; inflate so every instance falls inside.
// NOTE(review): if every distance is exactly 0 this leaves bandwidth
// at 0 and the division below yields NaN weights — confirm whether
// callers can hit this (e.g. duplicate-only training data).
bandwidth *= 10; // Include them all
}
}
// Rescale the distances by the bandwidth
for (int i = 0; i < distance.length; i++) {
distance[i] = distance[i] / bandwidth;
}
// Pass the distances through a weighting kernel to obtain weights in [0,1].
for (int i = 0; i < distance.length; i++) {
switch (m_WeightKernel) {
case LINEAR:
// 1.0001 (not 1.0) keeps the kth neighbour's weight just above zero.
distance[i] = Math.max(1.0001 - distance[i], 0);
break;
case INVERSE:
distance[i] = 1.0 / (1.0 + distance[i]);
break;
case GAUSS:
distance[i] = Math.exp(-distance[i] * distance[i]);
break;
}
}
if (m_Debug) {
System.out.println("Instance Weights");
for (int i = 0; i < distance.length; i++) {
System.out.println("" + distance[i]);
}
}
// Set the weights on a copy of the training data; iteration is in
// nearest-first order so the loop can stop at the first negligible weight.
Instances weightedTrain = new Instances(m_Train, 0);
for (int i = 0; i < distance.length; i++) {
double weight = distance[sortKey[i]];
if (weight < 1e-20) {
break;
}
Instance newInst = (Instance) m_Train.instance(sortKey[i]).copy();
// Scale (not replace) any pre-existing instance weight.
newInst.setWeight(newInst.weight() * weight);
weightedTrain.add(newInst);
}
if (m_Debug) {
System.out.println("Kept " + weightedTrain.numInstances() + " out of "
+ m_Train.numInstances() + " instances");
}
// Create a weighted linear regression
lr.buildClassifier(weightedTrain);
if (m_Debug) {
System.out.println("Classifying test instance: " + instance);
System.out.println("Built regression model:\n"
+ lr.toString());
}
// Return the linear regression's prediction
return lr.classifyInstance(instance);
}
/**
 * Returns a description of this classifier.
 *
 * @return a description of this classifier as a string.
 */
public String toString() {

  if (m_Train == null) {
    return "Locally weighted regression: No model built yet.";
  }

  StringBuffer text = new StringBuffer();
  text.append("Locally weighted regression\n");
  text.append("===========================\n");
  switch (m_WeightKernel) {
  case LINEAR:
    text.append("Using linear weighting kernels\n");
    break;
  case INVERSE:
    text.append("Using inverse-distance weighting kernels\n");
    break;
  case GAUSS:
    text.append("Using gaussian weighting kernels\n");
    break;
  }
  if (m_UseAllK) {
    text.append("Using all neighbours");
  } else {
    text.append("Using " + m_kNN + " neighbours");
  }
  return text.toString();
}
/**
 * Calculates the normalized Euclidean distance between two instances,
 * skipping the class attribute.
 *
 * @param first the first instance (typically the test instance)
 * @param second the second instance (typically a training instance)
 * @return the distance between the two given instances, between 0 and 1
 */
private double distance(Instance first, Instance second) {
double diff, distance = 0;
int numAttribsUsed = 0;
for(int i = 0; i < m_Train.numAttributes(); i++) {
// The class attribute never contributes to the distance.
if (i == m_Train.classIndex()) {
continue;
}
switch (m_Train.attribute(i).type()) {
case Attribute.NOMINAL:
// If attribute is nominal: 0 on exact match, otherwise (including any
// missing value) the maximal difference of 1.
numAttribsUsed++;
if (first.isMissing(i) || second.isMissing(i) ||
((int)first.value(i) != (int)second.value(i))) {
diff = 1;
} else {
diff = 0;
}
break;
case Attribute.NUMERIC:
// If attribute is numeric: compare min/max-normalized values.
numAttribsUsed++;
if (first.isMissing(i) || second.isMissing(i)) {
if (first.isMissing(i) && second.isMissing(i)) {
// Both missing: assume the maximal difference.
diff = 1;
} else {
// One missing: assume the worst-case difference from the known
// value, i.e. at least 0.5 after normalization.
if (second.isMissing(i)) {
diff = norm(first.value(i),i);
} else {
diff = norm(second.value(i),i);
}
if (diff < 0.5) {
diff = 1.0-diff;
}
}
} else {
diff = norm(first.value(i),i) - norm(second.value(i),i);
}
break;
default:
// Other attribute types (e.g. string) contribute nothing.
diff = 0;
break;
}
distance += diff * diff;
}
// NOTE(review): if the data has no non-class nominal/numeric attributes,
// numAttribsUsed is 0 and this returns NaN — confirm upstream checks
// prevent that case.
return Math.sqrt(distance / numAttribsUsed);
}
/**
 * Normalizes a given value of a numeric attribute into [0, 1] using the
 * observed minimum and maximum for that attribute.
 *
 * @param x the value to be normalized
 * @param i the attribute's index
 * @return the normalized value, or 0 when no range has been observed yet
 */
private double norm(double x,int i) {

  // No observed values yet, or a degenerate (zero-width) range.
  if (Double.isNaN(m_Min[i]) || Utils.eq(m_Max[i], m_Min[i])) {
    return 0;
  }
  return (x - m_Min[i]) / (m_Max[i] - m_Min[i]);
}
/**
 * Updates the minimum and maximum values for all the attributes
 * based on a new instance.
 *
 * @param instance the new instance
 */
private void updateMinMax(Instance instance) {

  for (int att = 0; att < m_Train.numAttributes(); att++) {
    if (instance.isMissing(att)) {
      continue;
    }
    double value = instance.value(att);
    if (Double.isNaN(m_Min[att])) {
      // First non-missing value seen for this attribute.
      m_Min[att] = value;
      m_Max[att] = value;
    } else if (value < m_Min[att]) {
      m_Min[att] = value;
    } else if (value > m_Max[att]) {
      m_Max[att] = value;
    }
  }
}
/**
 * Main method for testing this class.
 *
 * @param argv the options
 */
public static void main(String [] argv) {

  try {
    System.out.println(Evaluation.evaluateModel(new LWR(), argv));
  } catch (Exception ex) {
    System.err.println(ex.getMessage());
  }
}
}