/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* DecisionTable.java
* Copyright (C) 1999 Mark Hall
*
*/
package weka.classifiers.rules;
import weka.classifiers.Classifier;
import weka.classifiers.DistributionClassifier;
import weka.classifiers.Evaluation;
import weka.classifiers.lazy.IBk;
import weka.classifiers.lazy.IB1;
import java.io.*;
import java.util.*;
import weka.core.*;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
/**
* Class for building and using a simple decision table majority classifier.
* For more information see: <p>
*
* Kohavi R. (1995).<i> The Power of Decision Tables.</i> In Proc
* European Conference on Machine Learning.<p>
*
* Valid options are: <p>
*
* -S num <br>
* Number of fully expanded non improving subsets to consider
* before terminating a best first search.
* (Default = 5) <p>
*
* -X num <br>
* Use cross validation to evaluate features. Use number of folds = 1 for
* leave one out CV. (Default = leave one out CV) <p>
*
* -I <br>
* Use nearest neighbour instead of global table majority. <p>
*
* -R <br>
* Prints the decision table. <p>
*
* @author Mark Hall (mhall@cs.waikato.ac.nz)
* @version $Revision: 1.1.1.1 $
*/
public class DecisionTable
extends DistributionClassifier
implements OptionHandler, WeightedInstancesHandler,
AdditionalMeasureProducer {
/** Maps a hashKey (selected feature values of an instance) to a class
distribution: a double [] of per-class weight counts for a nominal class,
or {weighted sum of class values, sum of weights} for a numeric class */
private Hashtable m_entries;
/** Holds the final feature set chosen by the best first search
(attribute indices; includes the class index) */
private int [] m_decisionFeatures;
/** Discretization filter applied to the data before learning and prediction */
private Filter m_disTransform;
/** Filter used to remove columns discarded by feature selection */
private Remove m_delTransform;
/** IBk used to classify non matching instances rather than majority class */
private IBk m_ibk;
/** Holds the training instances (reduced to an empty header after training
to save memory) */
private Instances m_theInstances;
/** The number of attributes in the (filtered) dataset */
private int m_numAttributes;
/** The number of instances in the dataset */
private int m_numInstances;
/** True when the class attribute is nominal */
private boolean m_classIsNominal;
/** Output debug info */
private boolean m_debug;
/** Use the IBk classifier rather than majority class for non-matching instances */
private boolean m_useIBk;
/** Display the decision table rules in toString() */
private boolean m_displayRules;
/**
* Maximum number of fully expanded non improving subsets for a best
* first search.
*/
private int m_maxStale;
/** Number of folds for cross validating feature sets (1 = leave one out) */
private int m_CVFolds;
/** Random numbers for use in cross validation */
private Random m_rr;
/** Holds the majority class (mode for a nominal class, mean for numeric) */
private double m_majority;
/**
 * Class for a node in a linked list. Used in best first search.
 * Pairs a feature subset with its evaluation score.
 */
public class Link {

    /** The feature subset this node represents (held as a private copy). */
    BitSet m_group;

    /** The evaluation ("merit") of the subset. */
    double m_merit;

    /**
     * The constructor.
     *
     * @param gr the group
     * @param mer the merit
     */
    public Link(BitSet gr, double mer) {
        // clone so later changes to the caller's BitSet cannot affect this node
        m_group = (BitSet) gr.clone();
        m_merit = mer;
    }

    /**
     * Gets the group.
     */
    public BitSet getGroup() {
        return m_group;
    }

    /**
     * Gets the merit.
     */
    public double getMerit() {
        return m_merit;
    }

    /**
     * Returns string representation.
     */
    public String toString() {
        return "Node: " + m_group + " " + m_merit;
    }
}
/**
 * Class for handling a linked list of search nodes kept sorted by merit
 * (best first). Used in best first search. Extends the Vector class.
 */
public class LinkedList extends FastVector {

    /**
     * Removes an element (Link) at a specific index from the list.
     *
     * @param index the index of the element to be removed.
     * @exception Exception if the index is out of range
     */
    public void removeLinkAt(int index) throws Exception {
        if (index < 0 || index >= size()) {
            throw new Exception("index out of range (removeLinkAt)");
        }
        removeElementAt(index);
    }

    /**
     * Returns the element (Link) at a specific index from the list.
     *
     * @param index the index of the element to be returned.
     * @exception Exception if the list is empty or the index is out of range
     */
    public Link getLinkAt(int index) throws Exception {
        if (size() == 0) {
            throw new Exception("List is empty (getLinkAt)");
        }
        if (index < 0 || index >= size()) {
            throw new Exception("index out of range (getLinkAt)");
        }
        return (Link) elementAt(index);
    }

    /**
     * Adds an element (Link) to the list, keeping it sorted by merit in
     * descending order. Ties are placed after existing equal-merit entries.
     *
     * @param gr the feature set specification
     * @param mer the "merit" of this feature set
     */
    public void addToList(BitSet gr, double mer) {
        Link node = new Link(gr, mer);
        // advance past every entry with merit >= mer, then insert
        int pos = 0;
        while (pos < size() && mer <= ((Link) elementAt(pos)).getMerit()) {
            pos++;
        }
        if (pos == size()) {
            addElement(node);
        } else {
            insertElementAt(node, pos);
        }
    }
}
/**
 * Class providing keys to the hash table. A key records the (discretized)
 * attribute values of an instance; missing values and the class attribute
 * are flagged in a parallel boolean array so they never take part in
 * equality tests or hashing.
 */
public class hashKey implements Serializable {
/** Array of attribute values for an instance */
private double [] attributes;
/** True for an index if the corresponding attribute value is missing. */
private boolean [] missing;
/** The values. NOTE(review): never assigned in the visible code -- appears unused */
private String [] values;
/** Cached hash code; -999 is the "not yet computed" sentinel */
private int key;
/**
 * Constructor for a hashKey
 *
 * @param t an instance from which to generate a key
 * @param numAtts the number of attributes
 */
public hashKey(Instance t, int numAtts) throws Exception {
int i;
int cindex = t.classIndex();
key = -999;
attributes = new double [numAtts];
missing = new boolean [numAtts];
for (i=0;i<numAtts;i++) {
if (i == cindex) {
// the class attribute is always marked missing so keys only
// depend on the predictive features
missing[i] = true;
} else {
if ((missing[i] = t.isMissing(i)) == false) {
attributes[i] = t.value(i);
}
}
}
}
/**
 * Convert a hash entry to a string: one fixed-width column per
 * non-class attribute, with "?" for missing values.
 *
 * @param t the set of instances (provides attribute/value names)
 * @param maxColWidth width to make the fields
 * @return the formatted row of attribute values
 */
public String toString(Instances t, int maxColWidth) {
int i;
int cindex = t.classIndex();
StringBuffer text = new StringBuffer();
for (i=0;i<attributes.length;i++) {
if (i != cindex) {
if (missing[i]) {
text.append("?");
for (int j=0;j<maxColWidth;j++) {
text.append(" ");
}
} else {
// stored value is the index into the (discretized) nominal values
String ss = t.attribute(i).value((int)attributes[i]);
StringBuffer sb = new StringBuffer(ss);
// pad to maxColWidth + 1 so columns line up
for (int j=0;j < (maxColWidth-ss.length()+1); j++) {
sb.append(" ");
}
text.append(sb);
}
}
}
return text.toString();
}
/**
 * Constructor for a hashKey
 *
 * @param t an array of feature values; Double.MAX_VALUE marks a
 * missing value (and the class attribute's slot)
 */
public hashKey(double [] t) {
int i;
int l = t.length;
key = -999;
attributes = new double [l];
missing = new boolean [l];
for (i=0;i<l;i++) {
if (t[i] == Double.MAX_VALUE) {
missing[i] = true;
} else {
missing[i] = false;
attributes[i] = t[i];
}
}
}
/**
 * Calculates a hash code from the non-missing attribute values.
 * The result is cached in key after the first call (note: a key whose
 * legitimate hash equals the sentinel -999 is simply recomputed each time).
 *
 * @return the hash code as an integer
 */
public int hashCode() {
int hv = 0;
if (key != -999)
return key;
for (int i=0;i<attributes.length;i++) {
if (missing[i]) {
hv += (i*13);
} else {
// int += double compound assignment truncates per step
hv += (i * 5 * (attributes[i]+1));
}
}
if (key == -999) {
key = hv;
}
return hv;
}
/**
 * Tests if two keys are equal: same missing pattern and identical
 * values on all non-missing positions.
 *
 * @param b a key to compare with
 * @return true if the keys match
 */
public boolean equals(Object b) {
if ((b == null) || !(b.getClass().equals(this.getClass()))) {
return false;
}
boolean ok = true;
boolean l;
if (b instanceof hashKey) {
hashKey n = (hashKey)b;
for (int i=0;i<attributes.length;i++) {
l = n.missing[i];
if (missing[i] || l) {
// mismatch when exactly one side is missing
if ((missing[i] && !l) || (!missing[i] && l)) {
ok = false;
break;
}
} else {
if (attributes[i] != n.attributes[i]) {
ok = false;
break;
}
}
}
} else {
return false;
}
return ok;
}
/**
 * Prints the hash code (debugging aid).
 */
public void print_hash_code() {
System.out.println("Hash val: "+hashCode());
}
}
/**
 * Inserts an instance into the hash table, creating or updating the
 * class distribution stored under its key.
 *
 * @param inst instance to be inserted
 * @param instA if non-null, the feature values to key on (Double.MAX_VALUE
 * marks a missing value); if null, the key is built from all of inst's
 * attributes
 * @exception Exception if the instance can't be inserted
 */
private void insertIntoTable(Instance inst, double [] instA)
    throws Exception {

    hashKey thekey;
    if (instA != null) {
        thekey = new hashKey(instA);
    } else {
        thekey = new hashKey(inst, inst.numAttributes());
    }

    // see if this one is already in the table
    double [] dist = (double []) m_entries.get(thekey);
    if (dist == null) {
        // first instance with this key: create a fresh distribution
        if (m_classIsNominal) {
            dist = new double [m_theInstances.classAttribute().numValues()];
            dist[(int)inst.classValue()] = inst.weight();
        } else {
            // numeric class: store {weighted sum of class values, sum of weights}
            dist = new double [2];
            dist[0] = inst.classValue() * inst.weight();
            dist[1] = inst.weight();
        }
        m_entries.put(thekey, dist);
    } else {
        // the array is the one already stored in the table, so mutating it
        // in place suffices; the original redundant re-put is dropped
        if (m_classIsNominal) {
            dist[(int)inst.classValue()] += inst.weight();
        } else {
            dist[0] += (inst.classValue() * inst.weight());
            dist[1] += inst.weight();
        }
    }
}
/**
 * Classifies an instance for internal leave one out cross validation
 * of feature sets. The instance's own contribution is subtracted from a
 * copy of its stored class distribution before predicting.
 *
 * @param instance instance to be "left out" and classified
 * @param instA feature values of the selected features for the instance
 * @return the classification of the instance
 * @exception Exception (Error) if the instance's key is not in the table,
 * which indicates an internal inconsistency
 */
double classifyInstanceLeaveOneOut(Instance instance, double [] instA)
throws Exception {
hashKey thekey;
double [] tempDist;
double [] normDist;
thekey = new hashKey(instA);
if (m_classIsNominal) {
// every training instance was inserted before evaluation, so a miss
// should be impossible
if ((tempDist = (double [])m_entries.get(thekey)) == null) {
throw new Error("This should never happen!");
} else {
// work on a copy so the table entry itself stays intact
normDist = new double [tempDist.length];
System.arraycopy(tempDist,0,normDist,0,tempDist.length);
normDist[(int)instance.classValue()] -= instance.weight();
// first check to see if the class counts are all zero now
boolean ok = false;
for (int i=0;i<normDist.length;i++) {
if (!Utils.eq(normDist[i],0.0)) {
ok = true;
break;
}
}
if (ok) {
Utils.normalize(normDist);
return Utils.maxIndex(normDist);
} else {
// this instance was the only one with its key; fall back on majority
return m_majority;
}
}
} else {
// numeric class: entry holds {weighted sum of class values, sum of weights}
if ((tempDist = (double[])m_entries.get(thekey)) != null) {
normDist = new double [tempDist.length];
System.arraycopy(tempDist,0,normDist,0,tempDist.length);
normDist[0] -= (instance.classValue() * instance.weight());
normDist[1] -= instance.weight();
if (Utils.eq(normDist[1],0.0)) {
return m_majority;
} else {
// predict the mean of the remaining instances with this key
return (normDist[0] / normDist[1]);
}
} else {
throw new Error("This should never happen!");
}
}
}
/**
 * Calculates the accuracy on a test fold for internal cross validation
 * of feature sets. For a numeric class the returned value is actually a
 * sum of squared (weighted) errors rather than an accuracy -- the caller
 * (estimateAccuracy) negates and square-roots the aggregate.
 *
 * The fold's instances are temporarily subtracted from the live table
 * distributions, classified, and then re-added, so the table is restored
 * on return.
 *
 * @param fold set of instances to be "left out" and classified
 * @param fs currently selected feature set
 * @return the accuracy (nominal class) or squared-error sum (numeric class)
 * @exception Exception (Error) if a fold instance's key is not in the table
 */
double classifyFoldCV(Instances fold, int [] fs) throws Exception {
int i;
int ruleCount = 0;
int numFold = fold.numInstances();
int numCl = m_theInstances.classAttribute().numValues();
double [][] class_distribs = new double [numFold][numCl];
double [] instA = new double [fs.length];
double [] normDist;
hashKey thekey;
double acc = 0.0;
int classI = m_theInstances.classIndex();
Instance inst;
if (m_classIsNominal) {
normDist = new double [numCl];
} else {
normDist = new double [2];
}
// first *remove* instances: each row of class_distribs is re-pointed at
// the LIVE table array for that key, which is then decremented in place
for (i=0;i<numFold;i++) {
inst = fold.instance(i);
for (int j=0;j<fs.length;j++) {
if (fs[j] == classI) {
instA[j] = Double.MAX_VALUE; // missing for the class
} else if (inst.isMissing(fs[j])) {
instA[j] = Double.MAX_VALUE;
} else{
instA[j] = inst.value(fs[j]);
}
}
thekey = new hashKey(instA);
if ((class_distribs[i] = (double [])m_entries.get(thekey)) == null) {
throw new Error("This should never happen!");
} else {
if (m_classIsNominal) {
class_distribs[i][(int)inst.classValue()] -= inst.weight();
} else {
class_distribs[i][0] -= (inst.classValue() * inst.weight());
class_distribs[i][1] -= inst.weight();
}
ruleCount++;
}
}
// now classify instances against the fold-free distributions
for (i=0;i<numFold;i++) {
inst = fold.instance(i);
System.arraycopy(class_distribs[i],0,normDist,0,normDist.length);
if (m_classIsNominal) {
boolean ok = false;
for (int j=0;j<normDist.length;j++) {
if (!Utils.eq(normDist[j],0.0)) {
ok = true;
break;
}
}
if (ok) {
Utils.normalize(normDist);
if (Utils.maxIndex(normDist) == inst.classValue())
acc += inst.weight();
} else {
// key emptied by removal: fall back on the majority class
if (inst.classValue() == m_majority) {
acc += inst.weight();
}
}
} else {
if (Utils.eq(normDist[1],0.0)) {
// NOTE(review): the weight multiplies the error inside the square,
// so errors are scaled by weight^2 -- confirm this is intended
acc += ((inst.weight() * (m_majority - inst.classValue())) *
(inst.weight() * (m_majority - inst.classValue())));
} else {
double t = (normDist[0] / normDist[1]);
acc += ((inst.weight() * (t - inst.classValue())) *
(inst.weight() * (t - inst.classValue())));
}
}
}
// now re-insert instances, restoring the table to its original state
for (i=0;i<numFold;i++) {
inst = fold.instance(i);
if (m_classIsNominal) {
class_distribs[i][(int)inst.classValue()] += inst.weight();
} else {
class_distribs[i][0] += (inst.classValue() * inst.weight());
class_distribs[i][1] += inst.weight();
}
}
return acc;
}
/**
 * Evaluates a feature subset by cross validation. Rebuilds m_entries
 * (as a side effect) keyed on only the subset's features, then measures
 * leave-one-out or k-fold performance.
 *
 * @param feature_set the subset to be evaluated
 * @param num_atts the number of attributes in the subset
 * @return the estimated accuracy (nominal class) or the negated root mean
 * squared error (numeric class), so larger is always better
 * @exception Exception if subset can't be evaluated
 */
private double estimateAccuracy(BitSet feature_set, int num_atts)
throws Exception {
int i;
// NOTE(review): newInstances and evalArray are never used
Instances newInstances;
int [] fs = new int [num_atts];
double acc = 0.0;
double [][] evalArray;
double [] instA = new double [num_atts];
int classI = m_theInstances.classIndex();
int index = 0;
// collect the indices of the selected features
for (i=0;i<m_numAttributes;i++) {
if (feature_set.get(i)) {
fs[index++] = i;
}
}
// create new hash table (replaces any previous table)
m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5));
// insert instances into the hash table
for (i=0;i<m_numInstances;i++) {
Instance inst = m_theInstances.instance(i);
for (int j=0;j<fs.length;j++) {
if (fs[j] == classI) {
instA[j] = Double.MAX_VALUE; // missing for the class
} else if (inst.isMissing(fs[j])) {
instA[j] = Double.MAX_VALUE;
} else {
instA[j] = inst.value(fs[j]);
}
}
insertIntoTable(inst, instA);
}
if (m_CVFolds == 1) {
// calculate leave one out error
for (i=0;i<m_numInstances;i++) {
Instance inst = m_theInstances.instance(i);
for (int j=0;j<fs.length;j++) {
if (fs[j] == classI) {
instA[j] = Double.MAX_VALUE; // missing for the class
} else if (inst.isMissing(fs[j])) {
instA[j] = Double.MAX_VALUE;
} else {
instA[j] = inst.value(fs[j]);
}
}
double t = classifyInstanceLeaveOneOut(inst, instA);
if (m_classIsNominal) {
if (t == inst.classValue()) {
acc+=inst.weight();
}
} else {
// squared (weighted) error for a numeric class
acc += ((inst.weight() * (t - inst.classValue())) *
(inst.weight() * (t - inst.classValue())));
}
}
} else {
// k-fold cross validation; randomize/stratify mutates m_theInstances
m_theInstances.randomize(m_rr);
m_theInstances.stratify(m_CVFolds);
for (i=0;i<m_CVFolds;i++) {
Instances insts = m_theInstances.testCV(m_CVFolds,i);
acc += classifyFoldCV(insts, fs);
}
}
if (m_classIsNominal) {
return (acc / m_theInstances.sumOfWeights());
} else {
// negate RMSE so that higher values always mean better subsets
return -(Math.sqrt(acc / m_theInstances.sumOfWeights()));
}
}
/**
 * Returns a String representation of a feature subset: the 1-based index
 * of each selected attribute, each preceded by a space (e.g. " 1 3 7").
 *
 * @param sub BitSet representation of a subset
 * @return String containing subset
 */
private String printSub(BitSet sub) {
    // StringBuffer avoids the quadratic cost of repeated String +=;
    // the unused local of the original has been removed
    StringBuffer s = new StringBuffer();
    for (int jj = 0; jj < m_numAttributes; jj++) {
        if (sub.get(jj)) {
            s.append(" ").append(jj + 1);
        }
    }
    return s.toString();
}
/**
 * Does a best first forward search through the space of feature subsets,
 * starting from the subset containing only the class attribute. The search
 * stops after m_maxStale fully expanded subsets yield no improvement, and
 * the winning subset's indices are stored in m_decisionFeatures.
 *
 * @exception Exception if a subset cannot be evaluated
 */
private void best_first() throws Exception {
// NOTE(review): tree_count and evals are never used
int i,j,classI,count=0,fc,tree_count=0;
int evals=0;
BitSet best_group, temp_group;
int [] stale;
double [] best_merit;
double merit;
boolean z;
boolean added;
Link tl;
// lookup records every subset ever evaluated (value is a dummy "")
Hashtable lookup = new Hashtable((int)(200.0*m_numAttributes*1.5));
// bfList holds open subsets in descending merit order
LinkedList bfList = new LinkedList();
best_merit = new double[1]; best_merit[0] = 0.0;
stale = new int[1]; stale[0] = 0;
best_group = new BitSet(m_numAttributes);
// Add class to initial subset
classI = m_theInstances.classIndex();
best_group.set(classI);
best_merit[0] = estimateAccuracy(best_group, 1);
if (m_debug)
System.out.println("Accuracy of initial subset: "+best_merit[0]);
// add the initial group to the list
bfList.addToList(best_group,best_merit[0]);
// add initial subset to the hashtable
lookup.put(best_group,"");
while (stale[0] < m_maxStale) {
added = false;
// finished search?
if (bfList.size()==0) {
stale[0] = m_maxStale;
break;
}
// copy the feature set at the head of the list (current best open node)
tl = bfList.getLinkAt(0);
temp_group = (BitSet)(tl.getGroup().clone());
// remove the head of the list
bfList.removeLinkAt(0);
// expand: try adding each feature not yet in the subset
for (i=0;i<m_numAttributes;i++) {
z = ((i != classI) && (!temp_group.get(i)));
if (z) {
// set the bit (feature to add)
temp_group.set(i);
/* if this subset has been seen before, then it is already in
the list (or has been fully expanded) */
BitSet tt = (BitSet)temp_group.clone();
if (lookup.containsKey(tt) == false) {
// count the features in the candidate subset
fc = 0;
for (int jj=0;jj<m_numAttributes;jj++) {
if (tt.get(jj)) {
fc++;
}
}
merit = estimateAccuracy(tt, fc);
if (m_debug) {
System.out.println("evaluating: "+printSub(tt)+" "+merit);
}
// is this better than the best (beyond a small tolerance)?
z = ((merit - best_merit[0]) > 0.00001);
if (z) {
if (m_debug) {
System.out.println("new best feature set: "+printSub(tt)+
" "+merit);
}
added = true;
stale[0] = 0;
best_merit[0] = merit;
best_group = (BitSet)(temp_group.clone());
}
// insert this one in the list and the hash table
bfList.addToList(tt, merit);
lookup.put(tt,"");
count++;
}
// unset this addition
temp_group.clear(i);
}
}
/* if we haven't added a new feature subset then full expansion
of this node hasn't resulted in anything better */
if (!added) {
stale[0]++;
}
}
// set selected features: first count the bits, then collect the indices
for (i=0,j=0;i<m_numAttributes;i++) {
if (best_group.get(i)) {
j++;
}
}
m_decisionFeatures = new int[j];
for (i=0,j=0;i<m_numAttributes;i++) {
if (best_group.get(i)) {
m_decisionFeatures[j++] = i;
}
}
}
/**
 * Resets all options to their default values.
 */
protected void resetOptions() {
    m_entries = null;
    m_decisionFeatures = null;
    m_debug = false;
    m_useIBk = false;
    m_displayRules = false;
    m_CVFolds = 1;
    m_maxStale = 5;
}

/**
 * Constructor for a DecisionTable: starts from the default option settings.
 */
public DecisionTable() {
    resetOptions();
}
/**
 * Returns an enumeration describing the available options.
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions() {
    Vector opts = new Vector(5);
    opts.addElement(new Option(
        "\tNumber of fully expanded non improving subsets to consider\n"
        + "\tbefore terminating a best first search.\n"
        + "\tUse in conjunction with -B. (Default = 5)",
        "S", 1, "-S <number of non improving nodes>"));
    opts.addElement(new Option(
        "\tUse cross validation to evaluate features.\n"
        + "\tUse number of folds = 1 for leave one out CV.\n"
        + "\t(Default = leave one out CV)",
        "X", 1, "-X <number of folds>"));
    opts.addElement(new Option(
        "\tUse nearest neighbour instead of global table majority.\n",
        "I", 0, "-I"));
    opts.addElement(new Option(
        "\tDisplay decision table rules.\n",
        "R", 0, "-R"));
    return opts.elements();
}
/**
 * Sets the number of folds for cross validation (1 = leave one out)
 *
 * @param folds the number of folds
 */
public void setCrossVal(int folds) {
m_CVFolds = folds;
}
/**
 * Gets the number of folds for cross validation
 *
 * @return the number of cross validation folds (1 = leave one out)
 */
public int getCrossVal() {
return m_CVFolds;
}
/**
 * Sets the number of non improving decision tables to consider
 * before abandoning the best first search.
 *
 * @param stale the number of nodes
 */
public void setMaxStale(int stale) {
m_maxStale = stale;
}
/**
 * Gets the number of non improving decision tables considered before
 * the search terminates
 *
 * @return the number of non improving decision tables
 */
public int getMaxStale() {
return m_maxStale;
}
/**
 * Sets whether IBk should be used instead of the majority class for
 * instances not matched by any table entry
 *
 * @param ibk true if IBk is to be used
 */
public void setUseIBk(boolean ibk) {
m_useIBk = ibk;
}
/**
 * Gets whether IBk is being used instead of the majority class
 *
 * @return true if IBk is being used
 */
public boolean getUseIBk() {
return m_useIBk;
}
/**
 * Sets whether rules are to be printed by toString()
 *
 * @param rules true if rules are to be printed
 */
public void setDisplayRules(boolean rules) {
m_displayRules = rules;
}
/**
 * Gets whether rules are being printed
 *
 * @return true if rules are being printed
 */
public boolean getDisplayRules() {
return m_displayRules;
}
/**
 * Parses the options for this object.
 *
 * Valid options are: <p>
 *
 * -S num <br>
 * Number of fully expanded non improving subsets to consider
 * before terminating a best first search.
 * (Default = 5) <p>
 *
 * -X num <br>
 * Use cross validation to evaluate features. Use number of folds = 1 for
 * leave one out CV. (Default = leave one out CV) <p>
 *
 * -I <br>
 * Use nearest neighbour instead of global table majority. <p>
 *
 * -R <br>
 * Prints the decision table. <p>
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {
    resetOptions();

    String folds = Utils.getOption('X', options);
    if (folds.length() > 0) {
        m_CVFolds = Integer.parseInt(folds);
    }

    String stale = Utils.getOption('S', options);
    if (stale.length() > 0) {
        m_maxStale = Integer.parseInt(stale);
    }

    m_useIBk = Utils.getFlag('I', options);
    m_displayRules = Utils.getFlag('R', options);
}
/**
 * Gets the current settings of the classifier.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String [] getOptions() {
    String [] options = new String [7];
    int current = 0;
    options[current++] = "-X";
    options[current++] = String.valueOf(m_CVFolds);
    options[current++] = "-S";
    options[current++] = String.valueOf(m_maxStale);
    if (m_useIBk) {
        options[current++] = "-I";
    }
    if (m_displayRules) {
        options[current++] = "-R";
    }
    // pad unused slots with empty strings
    Arrays.fill(options, current, options.length, "");
    return options;
}
/**
 * Generates the classifier: discretizes the data, runs a best first
 * feature search, removes unselected attributes, and fills the hash
 * table of rules. Optionally builds an IBk for non-matching instances.
 *
 * @param data set of instances serving as training data
 * @exception Exception if the classifier has not been generated successfully
 */
public void buildClassifier(Instances data) throws Exception {
int i;
// fixed seed makes cross validation of feature sets deterministic
m_rr = new Random(1);
// copy so the caller's data is not modified
m_theInstances = new Instances(data);
m_theInstances.deleteWithMissingClass();
if (m_theInstances.numInstances() == 0) {
throw new Exception("No training instances without missing class!");
}
if (m_theInstances.checkForStringAttributes()) {
throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
}
if (m_theInstances.classAttribute().isNumeric()) {
m_disTransform = new weka.filters.unsupervised.attribute.Discretize();
m_classIsNominal = false;
// use binned discretisation if the class is numeric
((weka.filters.unsupervised.attribute.Discretize)m_disTransform).
setBins(10);
((weka.filters.unsupervised.attribute.Discretize)m_disTransform).
setInvertSelection(true);
// Discretize all attributes EXCEPT the class: the inverted selection
// lists only the class column (1-based index)
String rangeList = "";
rangeList+=(m_theInstances.classIndex()+1);
((weka.filters.unsupervised.attribute.Discretize)m_disTransform).
setAttributeIndices(rangeList);
} else {
// nominal class: supervised (entropy-based) discretization
m_disTransform = new weka.filters.supervised.attribute.Discretize();
((weka.filters.supervised.attribute.Discretize)m_disTransform).setUseBetterEncoding(true);
m_classIsNominal = true;
}
m_disTransform.setInputFormat(m_theInstances);
m_theInstances = Filter.useFilter(m_theInstances, m_disTransform);
m_numAttributes = m_theInstances.numAttributes();
m_numInstances = m_theInstances.numInstances();
// mode for a nominal class, mean for a numeric one
m_majority = m_theInstances.meanOrMode(m_theInstances.classAttribute());
// select features (fills m_decisionFeatures)
best_first();
// reduce instances to selected features
m_delTransform = new Remove();
m_delTransform.setInvertSelection(true);
// set features to keep
m_delTransform.setAttributeIndicesArray(m_decisionFeatures);
m_delTransform.setInputFormat(m_theInstances);
m_theInstances = Filter.useFilter(m_theInstances, m_delTransform);
// reset the number of attributes
m_numAttributes = m_theInstances.numAttributes();
// create hash table
m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5));
// insert instances into the hash table
for (i=0;i<m_numInstances;i++) {
Instance inst = m_theInstances.instance(i);
insertIntoTable(inst, null);
}
// Replace the global table majority with nearest neighbour?
if (m_useIBk) {
m_ibk = new IBk();
m_ibk.buildClassifier(m_theInstances);
}
// Save memory: keep only the dataset header
m_theInstances = new Instances(m_theInstances, 0);
}
/**
 * Calculates the class membership probabilities for the given
 * test instance. The instance is pushed through the same discretize and
 * remove filters used at training time before its key is looked up.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution (a single-element
 * array holding the predicted value when the class is numeric)
 * @exception Exception if distribution can't be computed
 */
public double [] distributionForInstance(Instance instance)
throws Exception {
hashKey thekey;
double [] tempDist;
double [] normDist;
// apply training-time filters: discretize, then drop unselected columns
m_disTransform.input(instance);
m_disTransform.batchFinished();
instance = m_disTransform.output();
m_delTransform.input(instance);
m_delTransform.batchFinished();
instance = m_delTransform.output();
thekey = new hashKey(instance, instance.numAttributes());
// if this one is not in the table
if ((tempDist = (double [])m_entries.get(thekey)) == null) {
if (m_useIBk) {
// fall back on the nearest neighbour classifier
tempDist = m_ibk.distributionForInstance(instance);
} else {
// fall back on the majority class (mean for a numeric class)
if (!m_classIsNominal) {
tempDist = new double[1];
tempDist[0] = m_majority;
} else {
tempDist = new double [m_theInstances.classAttribute().numValues()];
tempDist[(int)m_majority] = 1.0;
}
}
} else {
if (!m_classIsNominal) {
// numeric class: predict mean class value of matching instances
normDist = new double[1];
normDist[0] = (tempDist[0] / tempDist[1]);
tempDist = normDist;
} else {
// normalise a copy of the stored distribution
normDist = new double [tempDist.length];
System.arraycopy(tempDist,0,normDist,0,tempDist.length);
Utils.normalize(normDist);
tempDist = normDist;
}
}
return tempDist;
}
/**
 * Returns a string description of the features selected: the 1-based
 * attribute indices separated by commas (e.g. "1,3,5"), or the empty
 * string when no features are selected.
 *
 * @return a string of features
 */
public String printFeatures() {
    // StringBuffer avoids quadratic String += and removes the original's
    // special case for the first element
    StringBuffer s = new StringBuffer();
    for (int i = 0; i < m_decisionFeatures.length; i++) {
        if (i > 0) {
            s.append(",");
        }
        s.append(m_decisionFeatures[i] + 1);
    }
    return s.toString();
}
/**
 * Returns the number of rules (distinct keys in the hash table)
 * @return the number of rules
 */
public double measureNumRules() {
return m_entries.size();
}
/**
 * Returns an enumeration of the additional measure names
 * @return an enumeration of the measure names (only "measureNumRules")
 */
public Enumeration enumerateMeasures() {
Vector newVector = new Vector(1);
newVector.addElement("measureNumRules");
return newVector.elements();
}
/**
 * Returns the value of the named measure.
 *
 * @param additionalMeasureName the name of the measure to query for its value
 * @return the value of the named measure
 * @exception IllegalArgumentException if the named measure is not supported
 */
public double getMeasure(String additionalMeasureName) {
    // equals is the idiomatic equality test (compareTo(..) == 0 is for ordering)
    if (additionalMeasureName.equals("measureNumRules")) {
        return measureNumRules();
    }
    throw new IllegalArgumentException(additionalMeasureName
        + " not supported (DecisionTable)");
}
/**
 * Returns a description of the classifier: the search settings, the
 * selected feature set and, if requested, the full rule table.
 *
 * @return a description of the classifier as a string.
 */
public String toString() {
if (m_entries == null) {
return "Decision Table: No model built yet.";
} else {
StringBuffer text = new StringBuffer();
text.append("Decision Table:"+
"\n\nNumber of training instances: "+m_numInstances+
"\nNumber of Rules : "+m_entries.size()+"\n");
if (m_useIBk) {
text.append("Non matches covered by IB1.\n");
} else {
text.append("Non matches covered by Majority class.\n");
}
text.append("Best first search for feature set,\nterminated after "+
m_maxStale+" non improving subsets.\n");
text.append("Evaluation (for feature selection): CV ");
if (m_CVFolds > 1) {
text.append("("+m_CVFolds+" fold) ");
} else {
text.append("(leave one out) ");
}
text.append("\nFeature set: "+printFeatures());
if (m_displayRules) {
// find out the max column width over attribute names and values
int maxColWidth = 0;
for (int i=0;i<m_theInstances.numAttributes();i++) {
if (m_theInstances.attribute(i).name().length() > maxColWidth) {
maxColWidth = m_theInstances.attribute(i).name().length();
}
// skip the class attribute's values when the class is numeric
// (it is the only non-nominal attribute after discretization)
if (m_classIsNominal || (i != m_theInstances.classIndex())) {
Enumeration e = m_theInstances.attribute(i).enumerateValues();
while (e.hasMoreElements()) {
String ss = (String)e.nextElement();
if (ss.length() > maxColWidth) {
maxColWidth = ss.length();
}
}
}
}
text.append("\n\nRules:\n");
// build the header row of attribute names (class column last)
StringBuffer tm = new StringBuffer();
for (int i=0;i<m_theInstances.numAttributes();i++) {
if (m_theInstances.classIndex() != i) {
int d = maxColWidth - m_theInstances.attribute(i).name().length();
tm.append(m_theInstances.attribute(i).name());
for (int j=0;j<d+1;j++) {
tm.append(" ");
}
}
}
tm.append(m_theInstances.attribute(m_theInstances.classIndex()).name()+"  ");
// separator line above the header
for (int i=0;i<tm.length()+10;i++) {
text.append("=");
}
text.append("\n");
text.append(tm);
text.append("\n");
// separator line below the header
for (int i=0;i<tm.length()+10;i++) {
text.append("=");
}
text.append("\n");
// one row per rule: feature values then the predicted class
Enumeration e = m_entries.keys();
while (e.hasMoreElements()) {
hashKey tt = (hashKey)e.nextElement();
text.append(tt.toString(m_theInstances,maxColWidth));
double [] ClassDist = (double []) m_entries.get(tt);
if (m_classIsNominal) {
int m = Utils.maxIndex(ClassDist);
try {
text.append(m_theInstances.classAttribute().value(m)+"\n");
} catch (Exception ee) {
System.out.println(ee.getMessage());
}
} else {
// numeric class: print the mean class value for this rule
text.append((ClassDist[0] / ClassDist[1])+"\n");
}
}
// closing separator line
for (int i=0;i<tm.length()+10;i++) {
text.append("=");
}
text.append("\n");
text.append("\n");
}
return text.toString();
}
}
/**
 * Main method for testing this class: builds and evaluates a
 * DecisionTable from the command line.
 *
 * @param argv the command-line options
 */
public static void main(String [] argv) {
    try {
        Classifier scheme = new DecisionTable();
        System.out.println(Evaluation.evaluateModel(scheme, argv));
    } catch (Exception e) {
        e.printStackTrace();
        System.out.println(e.getMessage());
    }
}
}