/***********************************************************************
This file is part of KEEL-software, the Data Mining tool for regression,
classification, clustering, pattern mining and so on.
Copyright (C) 2004-2010
F. Herrera (herrera@decsai.ugr.es)
L. Sánchez (luciano@uniovi.es)
J. Alcalá-Fdez (jalcala@decsai.ugr.es)
S. García (sglopez@ujaen.es)
A. Fernández (alberto.fernandez@ujaen.es)
J. Luengo (julianlm@decsai.ugr.es)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see http://www.gnu.org/licenses/
**********************************************************************/
package keel.Algorithms.Genetic_Rule_Learning.M5Rules;
/**
* Implements the M5 model-tree learner used by M5Rules: it builds an unpruned
* model tree, prunes it and (optionally) smooths the linear models at its leaves.
*/
public final class M5 {
/** The root nodes: m_root[0] holds the unpruned tree, m_root[1] the pruned (and optionally smoothed) tree */
private M5TreeNode m_root[];
/** No smoothing? */
private boolean m_UseUnsmoothed = false;
/** Pruning factor */
private double m_PruningFactor = 2;
/** Type of model */
private int m_Model = M5TreeNode.MODEL_TREE;
/** Verbosity */
private int m_Verbosity = 0;
public static final int MODEL_LINEAR_REGRESSION = M5TreeNode.LINEAR_REGRESSION;
public static final int MODEL_REGRESSION_TREE = M5TreeNode.REGRESSION_TREE;
public static final int MODEL_MODEL_TREE = M5TreeNode.MODEL_TREE;
/** Training, validation and test datasets */
MyDataset trainDataset, valDataset, testDataset;
/**
* Constructor from a parsed parameters file.
* @param paramFile the parsed parameters file
* @throws java.lang.Exception if the class attribute of the dataset is not numeric.
*/
public M5(parseParameters paramFile) throws Exception{
//File Names
String trainFileName=paramFile.getTrainingInputFile();
String valFileName=paramFile.getValidationInputFile();
String testFileName=paramFile.getTestInputFile();
//Options
m_Model=MODEL_MODEL_TREE;
m_PruningFactor=Double.parseDouble(paramFile.getParameter(0)); //pruning factor (a in (n+a)/(n-k))
m_UseUnsmoothed=Boolean.valueOf(paramFile.getParameter(1)).booleanValue(); //true to use unsmoothed linear models (skip smoothing)
m_Verbosity = Integer.parseInt(paramFile.getParameter(2)); //verbosity level
if (m_PruningFactor < 0 || m_PruningFactor > 10) {
m_PruningFactor = 2;
System.err.println("Error: Pruning Factor must be in the interval [0,10]");
System.err.println("Using default value: 2");
}
if (m_Verbosity < 0 || m_Verbosity > 2) {
m_Verbosity = 0;
System.err.println("Error: Verbosity must be 0, 1 or 2");
System.err.println("Using default value: 0");
}
/* Initializes the dataset. */
trainDataset = new MyDataset( trainFileName, true );
valDataset = new MyDataset( valFileName, false );
testDataset = new MyDataset( testFileName, false );
if (trainDataset.getClassAttribute().isDiscret()) {
throw new Exception("Class has to be numeric.");
}
// generate the tree
buildClassifier( trainDataset );
}
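/* Illustrative usage sketch, not part of the original KEEL code. It assumes a
KEEL-style configuration file has already been parsed into a parseParameters
object; the exact parseParameters API may differ from what is shown here:
parseParameters params = new parseParameters();
params.parseConfigurationFile("M5config.txt"); // hypothetical configuration file
M5 model = new M5(params);                     // parses the options, loads the datasets and builds the tree
System.out.println(model);                     // prints the pruned model tree and its rules
*/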
/**
* Constructor from a dataset and explicit options.
* @param data the training dataset
* @param prune_factor the pruning factor (values outside [0,10] are reset to the default of 2)
* @param unsmoothed true to use unsmoothed linear models (skip smoothing)
* @param verbosity the verbosity level (values other than 0, 1 or 2 are reset to 0)
* @throws java.lang.Exception if the class attribute of the dataset is not numeric.
*/
public M5(MyDataset data,double prune_factor,boolean unsmoothed,int verbosity) throws Exception{
//Options
m_Model=MODEL_MODEL_TREE;
m_PruningFactor=prune_factor; //pruning factor (a in (n+a)/(n-k))
m_UseUnsmoothed=unsmoothed; //whether the tree must be smoothed or not
m_Verbosity = verbosity; //verbosity level
if (m_PruningFactor < 0 || m_PruningFactor > 10) {
m_PruningFactor = 2;
System.err.println("Error: Pruning Factor must be in the interval [0,10]");
System.err.println("Using default value: 2");
}
if (m_Verbosity < 0 || m_Verbosity > 2) {
m_Verbosity = 0;
System.err.println("Error: Verbosity must be 0, 1 or 2");
System.err.println("Using default value: 0");
}
if (data.getClassAttribute().isDiscret()) {
throw new Exception("Class has to be numeric.");
}
// generate the tree
buildClassifier( data );
}
/**
* Constructs a model tree from the training itemsets.
*
* @param inst training itemsets
* @exception Exception if the classifier can't be built
*/
public final void buildClassifier(MyDataset inst) throws Exception {
inst=inst.discretToBinary();
m_root = new M5TreeNode[2];
double deviation = stdDev(inst.getClassIndex(), inst);
m_root[0] = new M5TreeNode(inst, null,m_Model,m_PruningFactor,deviation); // build an empty tree
m_root[0].split(inst); // build the unpruned initial tree
m_root[0].numLeaves(0); // number the leaves of the unpruned tree
m_root[1] = m_root[0].copy(null); // make a copy of the unpruned tree
m_root[1].prune(); // prune the tree
if (!m_UseUnsmoothed) {
m_root[1].smoothen(); // compute the smoothed linear models at the leaves
m_root[1].numLeaves(0); // number the leaves of the pruned tree
}
}
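// Note (added comment): discretToBinary() converts nominal attributes into binary
// ones before tree construction. After buildClassifier(), m_root[0] keeps the
// unpruned tree and m_root[1] the pruned tree; if smoothing is enabled, the linear
// models at the leaves of m_root[1] are smoothed with the models of the nodes on
// the path to the root (see the smoothing formula (np+kq)/(n+k) further below).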
/**
* Gets the current settings of the Classifier.
*
* @return an array of strings suitable for passing to setOptions
*/
public String[] getOptions() {
String[] options = new String[7];
int current = 0;
switch (m_Model) {
case MODEL_MODEL_TREE:
options[current++] = "-O";
options[current++] = "m";
if (m_UseUnsmoothed) {
options[current++] = "-U";
}
break;
case MODEL_REGRESSION_TREE:
options[current++] = "-O";
options[current++] = "r";
break;
case MODEL_LINEAR_REGRESSION:
options[current++] = "-O";
options[current++] = "l";
break;
}
options[current++] = "-F";
options[current++] = "" + m_PruningFactor;
options[current++] = "-V";
options[current++] = "" + m_Verbosity;
while (current < options.length) {
options[current++] = "";
}
return options;
}
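// Illustrative note (added comment): for example, with a model tree, smoothing
// enabled, pruning factor 2 and verbosity 0, getOptions() would return
// {"-O", "m", "-F", "2.0", "-V", "0", ""} -- unused slots are padded with empty strings.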
/**
* Converts the output of the training process into a string
*
* @return the converted string
*/
public final String toString() {
try {
StringBuffer text = new StringBuffer();
double absDev = absDev(m_root[0].itemsets.getClassIndex(),m_root[0].itemsets);
if (m_Verbosity >= 1 && m_Model != M5TreeNode.LINEAR_REGRESSION) {
switch (m_root[0].model) {
case M5TreeNode.LINEAR_REGRESSION:
break;
case M5TreeNode.REGRESSION_TREE:
text.append("@Unpruned training regression tree:\n");
break;
case M5TreeNode.MODEL_TREE:
text.append("@Unpruned training model tree:\n");
break;
}
if (m_root[0].type == false) {
text.append("\n");
}
text.append(m_root[0].treeToString(0, absDev) + "\n");
text.append("@Models at the leaves:\n\n");
// the linear models at the leaves of the unpruned tree
text.append(m_root[0].formulaeToString(false) + "\n");
}
if (m_root[0].model != M5TreeNode.LINEAR_REGRESSION) {
switch (m_root[0].model) {
case M5TreeNode.LINEAR_REGRESSION:
break;
case M5TreeNode.REGRESSION_TREE:
text.append("@Pruned training regression tree:\n");
break;
case M5TreeNode.MODEL_TREE:
text.append("@Pruned training model tree:\n");
break;
}
if (m_root[1].type == false) {
text.append("\n");
}
text.append(m_root[1].treeToString(0, absDev) + "\n"); //the pruned tree
text.append("@Models at the leaves:\n");
if ((m_root[0].model != M5TreeNode.LINEAR_REGRESSION) && (m_UseUnsmoothed)) {
text.append("@Unsmoothed linear models at the leaves of the pruned tree (simple):\n");
//the unsmoothed linear models at the leaves of the pruned tree
text.append(m_root[1].formulaeToString(false) + "\n");
}
if ((m_root[0].model == M5TreeNode.MODEL_TREE) && (!m_UseUnsmoothed)) {
text.append("@Smoothed linear models at the leaves of the pruned tree (complex):\n");
text.append(m_root[1].formulaeToString(true) + "\n");
//the smoothed linear models at the leaves of the pruned tree
}
} else {
text.append("@Training linear regression model:\n");
text.append(m_root[1].unsmoothed.toString(m_root[1].itemsets,0) + "\n\n");
// print the linear regression model
}
text.append("@Number of Rules: " + m_root[1].numberOfLinearModels());
return text.toString();
} catch (Exception e) {
return "can't print m5' tree";
}
}
/**
* Returns the number of linear models.
* @return the number of linear models
*/
public double measureNumLinearModels() {
return m_root[1].numberOfLinearModels();
}
/**
* Returns the number of leaves in the tree.
* @return the number of leaves in the tree (same as the number of linear models and
* the number of rules)
*/
public double measureNumLeaves() {
return measureNumLinearModels();
}
/**
* Returns the number of rules.
* @return the number of rules (same as the number of linear models and
* the number of leaves in the tree)
*/
public double measureNumRules() {
return measureNumLinearModels();
}
/**
* Returns the value of the named measure
* @param additionalMeasureName the name of the measure to query for its value
* @return the value of the named measure
* @exception IllegalArgumentException if the named measure is not supported
*/
public double getMeasure(String additionalMeasureName) {
if (additionalMeasureName.compareTo("measureNumRules") == 0) {
return measureNumRules();
} else if (additionalMeasureName.compareTo("measureNumLinearModels") ==
0) {
return measureNumLinearModels();
} else if (additionalMeasureName.compareTo("measureNumLeaves") == 0) {
return measureNumLeaves();
} else {
throw new IllegalArgumentException(additionalMeasureName
+ " not supported (M5)");
}
}
/**
* Get the value of UseUnsmoothed.
*
* @return Value of UseUnsmoothed.
*/
public boolean getUseUnsmoothed() {
return m_UseUnsmoothed;
}
/**
* Set the value of UseUnsmoothed.
*
* @param v Value to assign to UseUnsmoothed.
*/
public void setUseUnsmoothed(boolean v) {
m_UseUnsmoothed = v;
}
/**
* Get the value of PruningFactor.
*
* @return Value of PruningFactor.
*/
public double getPruningFactor() {
return m_PruningFactor;
}
/**
* Set the value of PruningFactor.
*
* @param v Value to assign to PruningFactor.
*/
public void setPruningFactor(double v) {
m_PruningFactor = v;
}
/**
* Returns the root of the pruned tree.
* @return the root node of the pruned tree
*/
public M5TreeNode getTree(){
//if (m_UseUnsmoothed)
// return m_root[0];
//else
return m_root[1];
}
/**
* Get the value of Verbosity.
*
* @return Value of Verbosity.
*/
public int getVerbosity() {
return m_Verbosity;
}
/**
* Set the value of Verbosity.
*
* @param v Value to assign to Verbosity.
*/
public void setVerbosity(int v) {
m_Verbosity = v;
}
/**
* Tests whether any enumerated (discrete) attribute exists in the itemsets
* @param inst itemsets
* @return true if there is at least one; false if none
*/
public final static boolean hasEnumAttr(MyDataset inst) {
int j;
boolean b = false;
for (j = 0; j < inst.numAttributes(); j++) {
if (inst.getAttribute(j).isDiscret() == true) {
b = true;
}
}
return b;
}
/**
* Tests whether any missing value exists in the itemsets
* @param inst itemsets
* @return true if there is missing value(s); false if none
*/
public final static boolean hasMissing(MyDataset inst) {
int i, j;
boolean b = false;
for (i = 0; i < inst.numItemsets(); i++) {
for (j = 0; j < inst.numAttributes(); j++) {
if (inst.itemset(i).isMissing(j) == true) {
b = true;
}
}
}
return b;
}
/**
* Returns the sum of the values of an attribute over all itemsets
* @param attr an attribute
* @param inst itemsets
* @return the sum value
*/
public final static double sum(int attr, MyDataset inst) {
int i;
double sum = 0.0;
for (i = 0; i <= inst.numItemsets() - 1; i++) {
sum += inst.itemset(i).getValue(attr);
}
return sum;
}
/**
* Returns the sum of the squared values of an attribute over all itemsets
* @param attr an attribute
* @param inst itemsets
* @return the sum of squared values
*/
public final static double sqrSum(int attr, MyDataset inst) {
int i;
double sqrSum = 0.0, value;
for (i = 0; i <= inst.numItemsets() - 1; i++) {
value = inst.itemset(i).getValue(attr);
sqrSum += value * value;
}
return sqrSum;
}
/**
* Returns the standard deviation of the values of an attribute over all itemsets
* @param attr an attribute
* @param inst itemsets
* @return the standard deviation value
*/
public final static double stdDev(int attr, MyDataset inst) {
int i, count = 0;
double sd, va, sum = 0.0, sqrSum = 0.0, value;
for (i = 0; i <= inst.numItemsets() - 1; i++) {
count++;
value = inst.itemset(i).getValue(attr);
sum += value;
sqrSum += value * value;
}
if (count > 1) {
va = (sqrSum - sum * sum / count) / count;
va = Math.abs(va);
sd = Math.sqrt(va);
} else {
sd = 0.0;
}
return sd;
}
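// Worked example (added comment): for values {2, 4, 4, 4, 5, 5, 7, 9} of an attribute,
// sum = 40, sqrSum = 232 and count = 8, so the (population) variance is
// (232 - 40*40/8) / 8 = 4 and stdDev() returns sqrt(4) = 2.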
/**
* Returns the mean absolute deviation of the values of an attribute over all itemsets
* @param attr an attribute
* @param inst itemsets
* @return the absolute deviation value
*/
public final static double absDev(int attr, MyDataset inst) {
int i;
double average = 0.0, absdiff = 0.0, absDev;
for (i = 0; i <= inst.numItemsets() - 1; i++) {
average += inst.itemset(i).getValue(attr);
}
if (inst.numItemsets() > 1) {
average /= (double) inst.numItemsets();
for (i = 0; i <= inst.numItemsets() - 1; i++) {
absdiff += Math.abs(inst.itemset(i).getValue(attr) - average);
}
absDev = absdiff / (double) inst.numItemsets();
} else {
absDev = 0.0;
}
return absDev;
}
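// Worked example (added comment): for values {1, 2, 3, 4, 5}, the average is 3 and the
// absolute differences are {2, 1, 0, 1, 2}, so absDev() returns (2+1+0+1+2)/5 = 1.2.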
/**
* Returns the variance of the values of an attribute over all itemsets
* @param attr an attribute
* @param inst itemsets
* @return the variance value
*/
public final static double variance(int attr, MyDataset inst) {
int i, count = 0;
double value, sum = 0.0, sqrSum = 0.0, va;
for (i = 0; i <= inst.numItemsets() - 1; i++) {
value = inst.itemset(i).getValue(attr);
sum += value;
sqrSum += value * value;
count++;
}
if (count > 0) {
va = (sqrSum - sum * sum / count) / count;
} else {
va = 0.0;
}
return va;
}
/**
* Rounds a double to the nearest long, with halves rounded away from zero
* @param value the double value
* @return the rounded value as a long
*/
public final static long roundDouble(double value) {
long roundedValue;
roundedValue = value > 0 ? (long) (value + 0.5) :
-(long) (Math.abs(value) + 0.5);
return roundedValue;
}
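// Examples (added comment): roundDouble(2.4) == 2, roundDouble(2.5) == 3 and
// roundDouble(-2.5) == -3, i.e. halves are rounded away from zero.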
/**
* Returns the largest (closest to positive infinity) long integer value that is not greater than the argument.
* @param value the double value
* @return the floor integer
*/
public final static long floorDouble(double value) {
long floorValue;
floorValue = value > 0 ? (long) value : -(long) (Math.abs(value) + 1);
return floorValue;
}
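// Examples (added comment): floorDouble(2.7) == 2 and floorDouble(-2.3) == -3.
// Note that, unlike Math.floor, exact negative integers are shifted down by one
// (e.g. floorDouble(-2.0) == -3), since the implementation always adds 1 before truncating.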
/**
* Rounds a double and converts it into a formatted right-justified String.
* It behaves like the %f format in C.
* @param value the double value
* @param width the width of the string
* @param afterDecimalPoint the number of digits after the decimal point
* @return the double as a formatted string
*/
public final static String doubleToStringF(double value, int width,
int afterDecimalPoint) {
StringBuffer stringBuffer;
String resultString;
double temp;
int i, dotPosition;
long precisionValue;
if (afterDecimalPoint < 0) {
afterDecimalPoint = 0;
}
precisionValue = 0;
temp = value * Math.pow(10.0, afterDecimalPoint);
if (Math.abs(temp) < Long.MAX_VALUE) {
precisionValue = roundDouble(temp);
if (precisionValue == 0) {
resultString = String.valueOf(0);
stringBuffer = new StringBuffer(resultString);
stringBuffer.append(".");
for (i = 1; i <= afterDecimalPoint; i++) {
stringBuffer.append("0");
}
resultString = stringBuffer.toString();
} else {
resultString = String.valueOf(precisionValue);
stringBuffer = new StringBuffer(resultString);
dotPosition = stringBuffer.length() - afterDecimalPoint;
while (dotPosition < 0) {
stringBuffer.insert(0, 0);
dotPosition++;
}
stringBuffer.insert(dotPosition, ".");
if (stringBuffer.charAt(0) == '.') {
stringBuffer.insert(0, 0);
}
resultString = stringBuffer.toString();
}
} else {
resultString = "NaN";
}
// Fill in space characters.
stringBuffer = new StringBuffer(Math.max(width, resultString.length()));
for (i = 0; i < stringBuffer.capacity() - resultString.length(); i++) {
stringBuffer.append(' ');
}
stringBuffer.append(resultString);
return stringBuffer.toString();
}
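// Example (added comment): doubleToStringF(3.14159, 8, 2) rounds to "3.14" and pads it
// on the left with spaces up to width 8, giving "    3.14".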
/**
* Rounds a double and converts it into a formatted right-justified String. If the double is not equal to zero and not in the range [10^-3, 10^7], it is returned in scientific format.
* It behaves like the %g format in C.
* @param value the double value
* @param width the width of the string
* @param precision the number of valid digits
* @return the double as a formatted string
*/
public final static String doubleToStringG(double value, int width,
int precision) {
StringBuffer stringBuffer;
String resultString;
double temp;
int i, dotPosition, exponent = 0;
long precisionValue;
if (precision <= 0) {
precision = 1;
}
precisionValue = 0;
exponent = 0;
if (value != 0.0) {
exponent = (int) floorDouble(Math.log(Math.abs(value)) /
Math.log(10));
temp = value * Math.pow(10.0, precision - exponent - 1);
precisionValue = roundDouble(temp); // then output value = precisionValue * pow(10,exponent+1-precision)
if (precision - 1 !=
(int) (Math.log(Math.abs(precisionValue) + 0.5) / Math.log(10))) {
exponent++;
precisionValue = roundDouble(precisionValue / 10.0);
}
}
if (precisionValue == 0) { // value = 0.0
resultString = String.valueOf("0");
} else {
if (precisionValue >= 0) {
dotPosition = 1;
} else {
dotPosition = 2;
}
if (exponent < -3 || precision - 1 + exponent > 7) { // Scientific format.
resultString = String.valueOf(precisionValue);
stringBuffer = new StringBuffer(resultString);
stringBuffer.insert(dotPosition, ".");
stringBuffer = deleteTrailingZerosAndDot(stringBuffer);
stringBuffer.append("e").append(String.valueOf(exponent));
resultString = stringBuffer.toString();
} else { // Fixed-point format.
resultString = String.valueOf(precisionValue);
stringBuffer = new StringBuffer(resultString);
for (i = 1; i <= -exponent; i++) {
stringBuffer.insert(dotPosition - 1, "0");
}
if (exponent <= -1) {
stringBuffer.insert(dotPosition, ".");
} else if (exponent <= precision - 1) {
stringBuffer.insert(dotPosition + exponent, ".");
} else {
for (i = 1; i <= exponent - (precision - 1); i++) {
stringBuffer.append("0");
}
stringBuffer.append(".");
}
// deleting trailing zeros and dot
stringBuffer = deleteTrailingZerosAndDot(stringBuffer);
resultString = stringBuffer.toString();
}
}
// Fill in space characters.
stringBuffer = new StringBuffer(Math.max(width, resultString.length()));
for (i = 0; i < stringBuffer.capacity() - resultString.length(); i++) {
stringBuffer.append(' ');
}
stringBuffer.append(resultString);
return stringBuffer.toString();
}
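// Examples (added comment): doubleToStringG(123.456, 10, 4) keeps 4 significant digits
// and yields "123.5" right-justified to width 10, while a small value such as 0.000123
// with precision 3 falls outside [10^-3, 10^7] and is rendered in scientific form, "1.23e-4".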
/**
* Deletes the trailing zeros and decimal point in a stringBuffer
* @param stringBuffer string buffer
* @return string buffer with deleted trailing zeros and decimal point
*/
public final static StringBuffer deleteTrailingZerosAndDot(StringBuffer
stringBuffer) {
while (stringBuffer.charAt(stringBuffer.length() - 1) == '0' ||
stringBuffer.charAt(stringBuffer.length() - 1) == '.') {
if (stringBuffer.charAt(stringBuffer.length() - 1) == '0') {
stringBuffer.setLength(stringBuffer.length() - 1);
} else {
stringBuffer.setLength(stringBuffer.length() - 1);
break;
}
}
return stringBuffer;
}
/**
* Returns the smoothed value according to the smoothing formula (np+kq)/(n+k)
* @param p a double, normally the prediction of the model at the current node
* @param q a double, normally the prediction of the model at the node above
* @param n the number of itemsets at the node above
* @param k the smoothing constant (default 15)
* @return the smoothed value
*/
public final static double smoothenValue(double p, double q, int n, int k) {
return (n * p + k * q) / (double) (n + k);
}
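// Worked example (added comment): with p = 3.0 (current node), q = 5.0 (node above),
// n = 10 and k = 15, smoothenValue(3.0, 5.0, 10, 15) = (10*3.0 + 15*5.0) / (10 + 15) = 4.2,
// i.e. the prediction is pulled towards the model of the node above.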
/**
* Returns the correlation coefficient of two double vectors
* @param y1 double vector 1
* @param y2 double vector 2
* @param n the length of two double vectors
* @return the correlation coefficient
*/
public final static double correlation(double y1[], double y2[], int n) {
int i;
double av1 = 0.0, av2 = 0.0, y11 = 0.0, y22 = 0.0, y12 = 0.0, c;
if (n <= 1) {
return 1.0;
}
for (i = 0; i < n; i++) {
av1 += y1[i];
av2 += y2[i];
}
av1 /= (double) n;
av2 /= (double) n;
for (i = 0; i < n; i++) {
y11 += (y1[i] - av1) * (y1[i] - av1);
y22 += (y2[i] - av2) * (y2[i] - av2);
y12 += (y1[i] - av1) * (y2[i] - av2);
}
if (y11 * y22 == 0.0) {
c = 1.0;
} else {
c = y12 / Math.sqrt(Math.abs(y11 * y22));
}
return c;
}
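// Worked example (added comment): for y1 = {1, 2, 3} and y2 = {2, 4, 6} (n = 3) the two
// vectors are perfectly linearly related, so correlation(y1, y2, 3) returns 1.0;
// degenerate cases (n <= 1 or zero variance) also return 1.0 by convention.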
/**
* Tests if two double values are equal to each other
* @param a double 1
* @param b double 2
* @return true if equal; false if not equal
*/
public final static boolean eqDouble(double a, double b) {
if (Math.abs(a) < 1e-10 && Math.abs(b) < 1e-10) {
return true;
}
double c = Math.abs(a) + Math.abs(b);
if (Math.abs(a - b) < c * 1e-10) {
return true;
} else {
return false;
}
}
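// Examples (added comment): eqDouble(1.0, 1.0 + 1e-12) is true (difference below the
// relative tolerance of 1e-10), while eqDouble(1.0, 1.001) is false; two values whose
// magnitudes are both below 1e-10 are always considered equal.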
/**
* Prints error message and exits
* @param err error message
*/
public final static void errorMsg(String err) {
System.out.print("Error: ");
System.out.println(err);
System.exit(1);
}
}