/***********************************************************************
This file is part of KEEL-software, the Data Mining tool for regression,
classification, clustering, pattern mining and so on.
Copyright (C) 2004-2010
F. Herrera (herrera@decsai.ugr.es)
L. Sánchez (luciano@uniovi.es)
J. Alcalá-Fdez (jalcala@decsai.ugr.es)
S. García (sglopez@ujaen.es)
A. Fernández (alberto.fernandez@ujaen.es)
J. Luengo (julianlm@decsai.ugr.es)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see http://www.gnu.org/licenses/
**********************************************************************/
/**
* <p>
* @author Written by Manuel Moreno (Universidad de Córdoba) 01/07/2008
* @version 0.1
* @since JDK 1.5
*</p>
*/
package keel.Algorithms.Decision_Trees.CART;
import java.util.ArrayList;
import java.util.Arrays;
import keel.Algorithms.Decision_Trees.CART.impurities.IImpurityFunction;
import keel.Algorithms.Decision_Trees.CART.tree.DecisionTree;
import keel.Algorithms.Decision_Trees.CART.tree.TreeNode;
import keel.Algorithms.Neural_Networks.NNEP_Common.data.DoubleTransposedDataSet;
/**
* Main class of the CART algorithm: Classification And Regression Trees (Breiman et al., 1984). CART builds binary decision trees.
*
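* <p>
* A minimal usage sketch (data loading and the concrete
* {@link keel.Algorithms.Decision_Trees.CART.impurities.IImpurityFunction}
* implementation are assumed to be provided elsewhere; the parameter values
* are illustrative only):
* </p>
* <pre>{@code
* DoubleTransposedDataSet data = ...;   // training data, already loaded
* IImpurityFunction impurity = ...;     // e.g. a Gini-based implementation
* CART cart = new CART(data, impurity);
* cart.setMaxDepth(10);
* cart.setRegression(false);            // classification problem
* cart.build_tree();
* byte[][] predictions = cart.getClassificationResults(data);
* }</pre>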
*/
public class CART
{
/** CART tree */
private DecisionTree tree;
/** Maximum Depth */
private int maxDepth;
/** Regression flag.
* This flag is true when dealing with regression problems.
* False for classification problems.
*/
private boolean regression;
/** Impurity function. Gini and Twoing are commonly used */
private IImpurityFunction impurityFunction;
/** Data set used to build the tree */
private DoubleTransposedDataSet dataset;
/////////////////////////////////////////////////////////////////////
// ------------------------------------------------------ Constructor
/////////////////////////////////////////////////////////////////////
/**
* Constructor receiving only the data set
*
* @param dataset Dataset to learn
*/
public CART (DoubleTransposedDataSet dataset) {
this.dataset = dataset;
}
/**
* Constructor with impurity function
*
* @param dataset Dataset to learn
* @param impurityFunction the impurity function
*/
public CART (DoubleTransposedDataSet dataset, IImpurityFunction impurityFunction) {
this.dataset = dataset;
// Set impurity function
this.impurityFunction = impurityFunction;
this.impurityFunction.setDataset(dataset);
}
/////////////////////////////////////////////////////////////////////
// ---------------------------------------------- Getters and Setters
/////////////////////////////////////////////////////////////////////
/**
* It returns the decision tree
*
* @return the decision tree
*/
public DecisionTree getTree()
{
return tree;
}
/**
* It returns the impurity function
*
* @return the impurityFunction
*/
public IImpurityFunction getImpurityFunction()
{
return impurityFunction;
}
/**
*
* It sets the impurity function
*
* @param impurityFunction the impurityFunction to set
*/
public void setImpurityFunction(IImpurityFunction impurityFunction)
{
this.impurityFunction = impurityFunction;
this.impurityFunction.setDataset(dataset);
}
/**
* It returns the maximal depth
*
* @return the maxDepth
*/
public int getMaxDepth() {
return maxDepth;
}
/**
* It sets the maximal depth
*
* @param maxDepth the maxDepth to set
*/
public void setMaxDepth(int maxDepth) {
this.maxDepth = maxDepth;
}
/**
* Returns whether we are dealing with a regression problem
*
* @return the regression
*/
public boolean isRegression() {
return regression;
}
/**
* It sets whether we are dealing with a regression problem
*
* @param regression the regression to set
*/
public void setRegression(boolean regression) {
this.regression = regression;
}
/////////////////////////////////////////////////////////////////////
// ---------------------------------------------------------- Methods
/////////////////////////////////////////////////////////////////////
/**
* This function finds the candidate splitting values for each input variable
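* For example, if the sorted values of an input over the node's patterns are
* {1.0, 3.0, 5.0}, the candidate splits are the midpoints {2.0, 4.0}.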
*
* @param patterns pattern indexes
* @return candidate splitting values, one row per input variable
*/
private double[][] splittingValues(int [] patterns)
{
int ninputs = dataset.getNofinputs();
int npatterns = patterns.length;
// Reserve memory for result
double [][] splittingValues = new double[ninputs][npatterns-1];
// compute candidate values as midpoints of adjacent sorted values
for (int j=0; j<ninputs; j++) {
// Get all values for input j
double [] aux = dataset.getObservationsOf(j);
// Copy the values so that sorting does not modify the data set
double [] x_j = new double[npatterns];
for (int i=0; i<npatterns; i++) {
int patternIndex = patterns[i];
x_j[i] = aux[patternIndex];
}
// Sort values in vector x_j from min to max value
Arrays.sort(x_j);
// get splitting values as the middle of adjacent values (Xi + Xi+1)/2
for (int i=0; i<x_j.length-1; i++) {
splittingValues[j][i] = (x_j[i] + x_j[i+1])/2;
}
}
return splittingValues;
}
/**
* Constructs decision tree
*/
public void build_tree ()
{
// Create tree
tree = new DecisionTree();
// Create root node
// Root node contains all patterns in data set
int [] patterns = new int[dataset.getNofobservations()];
for (int i=0; i<patterns.length; i++)
patterns[i]=i;//each index points to each pattern in data set
TreeNode root = new TreeNode(null,patterns);
// Set root node of the tree
tree.setRoot(root);
// Make tree grow
grow(root);
}
/**
* Recursive function that receives a node and checks whether it can be split.
* If so, it adds two sons and tries to grow them.
* @param node the node to check
*/
private void grow(TreeNode node)
{
if (node == null) // Check if node is null
return;
else {
if (stopCriteria(node)) {// Check stop criteria
if (regression)
assignMean(node); // Assign its output value
else
assignClass(node); // Assign its output class.
// Stop growing this branch
return;
}
else {
splitNode(node); // Split node
if (regression)
assignMean(node); // Assign its output value
else
assignClass(node); // Assign its output class.
// TODO: add mean assignment
grow(node.getLeftSon()); // grow left node
grow(node.getRightSon()); // grow right node
}
}
}
/**
* It splits a node into two sons
*
* @param node node to split into two sons
*/
private void splitNode(TreeNode node)
{
//double time = System.currentTimeMillis();
// Consider each variable x_j at a time
int ninputs = dataset.getNofinputs();
int npatterns = node.getPatterns().length;
double [][] gains = new double[ninputs][npatterns-1];
int bestSplit_i=0;
int bestSplit_j=0;
// Cache the impurity of the current node (this improves performance)
try {
node.setImpurities(impurityFunction.impurities(node.getPatterns(), 1));
} catch (Exception e) {
e.printStackTrace();
}
// For each input variable x_j calculate possible splitting values x*_j
// as the middle of adjacent values (x^i_j + x^i+1_j)/2
double [][] splittingValues = splittingValues(node.getPatterns());
for (int j=0; j<ninputs; j++) {
// Among all questions x_j <= x*j choose the "best" (highest
// change of impurity)
// Check every split value
for (int i=0; i<npatterns-1; i++) {
// Compute gain for current node and input variable j for given value
gains[j][i] = computeImpuritiesGain(node, j, splittingValues[j][i]);
//System.out.println("Gain["+j+"]["+i+"]="+gains[j][i]);
if (gains[j][i] >= gains[bestSplit_j][bestSplit_i]) { // Gain -> max
bestSplit_i = i;
bestSplit_j = j;
}
}
}
// Set variable and value
node.setVariable(bestSplit_j);
node.setValue(splittingValues[bestSplit_j][bestSplit_i]);
// Split node in two sons
ArrayList<int []> arrays = dividePatterns(node); // Fill patterns toLeft and toRight
int [] toLeft = arrays.get(0);
int [] toRight = arrays.get(1);
// Create sons
TreeNode leftSon = new TreeNode(node, toLeft);
TreeNode rightSon = new TreeNode(node, toRight);
// Link sons
node.setLeftSon(leftSon);
node.setRightSon(rightSon);
//System.out.println("Split node: "+(System.currentTimeMillis()-time)+"ms");
}
/**
* This function assigns an output class to a node.
* Its class label is the majority class among the node's patterns.
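* Class membership is read from the data set's outputs matrix (assumed to be
* 1-of-N encoded): outputs[class][pattern] == 1.0 marks the pattern as
* belonging to that class.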
*
* @param node tree node to assign its output class
*/
private void assignClass(TreeNode node)
{
//double time = System.currentTimeMillis();
// Patterns in current node
int [] patterns = node.getPatterns();
// Data set outputs
double [][] outputs = dataset.getAllOutputs();
// Counter of patterns in each class
int [] patternsInClass = new int [outputs.length];
// Determine majority class in current node
for (int i=0; i<outputs.length; i++) { // For each class
for (int j=0; j<patterns.length; j++) { // For each pattern
int patternIndex = patterns[j];
if ( outputs[i][patternIndex] == 1.0) // if the pattern belongs to class i
patternsInClass[i]++;
}
}
// Find majority class
int majorityClass = 0;
for (int i=1; i<patternsInClass.length; i++) {
if (patternsInClass[i] > patternsInClass[majorityClass])
majorityClass = i;
}
// Assign majority class to the node
node.setOutputClass(majorityClass);
//System.out.println("Asign Classes: "+(System.currentTimeMillis()-time)+"ms");
}
/**
* Assigns the predicted value as the mean of the outputs
* of the patterns in this node.
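* In formula form, with N(t) the number of patterns in node t:
* <pre>
*   output(t) = (1 / N(t)) * sum over patterns i in t of y_i
* </pre>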
*
* @param node TreeNode to compute
*/
private void assignMean(TreeNode node)
{
// Patterns in current node
int [] patterns = node.getPatterns();
// Data set outputs
double [] outputs = dataset.getOutput(0);
// Compute mean
double mean = 0;
for (int i=0; i<patterns.length; i++) {
int patternIndex = patterns[i];
mean += outputs[patternIndex];
}
mean = mean/patterns.length;
// Assign mean
node.setOutputValue(mean);
}
/**
* This function divides the patterns associated with a node into two groups
* using its split variable and value, according to the condition (variable <= splitValue)
*
* @param from Node to split in two branches. It must contain variable, split value and associated patterns.
* @return a list with two arrays: index 0 holds the pattern indexes sent to the
*         left branch and index 1 those sent to the right branch
*/
private ArrayList<int[]> dividePatterns(TreeNode from)
{
//double time = System.currentTimeMillis();
int [] patterns = from.getPatterns();
int variable = from.getVariable();
double limitValue = from.getValue();
ArrayList<Integer> leftBranch = new ArrayList<Integer>();
ArrayList<Integer> rightBranch = new ArrayList<Integer>();
// Divide patterns using condition variable <= value
for (int j=0; j< patterns.length; j++) {
int patternIndex = patterns[j];
double patternValue = dataset.getAllInputs()[variable][patternIndex];
// This pattern goes to left or right branch?
if (patternValue <= limitValue)
leftBranch.add(patternIndex);
else
rightBranch.add(patternIndex);
}
// Convert into arrays
int [] toLeft = new int [leftBranch.size()];
for (int i=0; i<toLeft.length; i++)
toLeft[i] = leftBranch.get(i);
int [] toRight = new int [rightBranch.size()];
for (int i=0; i<toRight.length; i++)
toRight[i] = rightBranch.get(i);
// Construct a List for result
ArrayList<int[]> result = new ArrayList<int[]>();
result.add(toLeft);
result.add(toRight);
//System.out.println("Divide Patterns: "+(System.currentTimeMillis()-time)+"ms");
return result;
}
/**
* This function calculates the impurity decrease (gain) between the
* parent (current node) and its two sons.
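* The gain computed below follows the standard CART criterion:
* <pre>
*   gain = i(t) - P_L * i(t_L) - P_R * i(t_R)
* </pre>
* where P_L and P_R are the fractions of the node's patterns sent to the
* left and right branches.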
*
* @param node Current node
* @param inputvar Input Data set variable index
* @param limitValue Limit value to compare
* @return impurities gain
*/
private double computeImpuritiesGain(TreeNode node, int inputvar, double limitValue)
{
ArrayList<Integer> leftBranch = new ArrayList<Integer>();
ArrayList<Integer> rightBranch = new ArrayList<Integer>();
int [] patterns = node.getPatterns();
// For each pattern in parent node
for (int j=0; j< patterns.length; j++) {
int patternIndex = patterns[j];
double patternValue = dataset.getAllInputs()[inputvar][patterns[j]];
// This pattern goes to left or right branch?
if (patternValue <= limitValue)
leftBranch.add(patternIndex);
else
rightBranch.add(patternIndex);
}
// Compute right, left and parent impurities in order
// to obtain the gain
int [] leftPatterns = new int [leftBranch.size()];
int [] rightPatterns = new int [rightBranch.size()];
double parentImpurities = 0.0;
double leftImpurities = 0.0;
double rightImpurities = 0.0;
try { // Impurities functions can throw exceptions
// obtain impurities in left branch (using cost 1)
for (int i=0; i<leftPatterns.length; i++)
leftPatterns[i] = leftBranch.get(i);
leftImpurities = impurityFunction.impurities(leftPatterns, 1);
// obtain impurities in right branch
for (int i=0; i<rightPatterns.length; i++)
rightPatterns[i] = rightBranch.get(i);
rightImpurities = impurityFunction.impurities(rightPatterns, 1);
// obtain impurities of current node
// parentImpurities = impurityFunction.impurities(patterns, 1);
parentImpurities = node.getImpurities();
} catch (Exception e) {
e.printStackTrace();
}
// return i(t) - P_l*i(t_l) - P_r*i(t_r)
double P_l = leftPatterns.length/(double)patterns.length;
double P_r = rightPatterns.length/(double)patterns.length;
// System.out.println("Gain: "+parentImpurities+"-("+P_l+"*"+leftImpurities+" + "+P_r+"*"+rightImpurities+")");
return ( parentImpurities - (P_l*leftImpurities) - (P_r*rightImpurities));
}
/**
* Prunes the decision tree. Currently a placeholder; no pruning is performed.
*/
public void prune_tree() {
// TODO A prune method can be used
}
/**
* It checks whether a stopping criterion has been reached for the given node
*
* @param node node to check
* @return true if a stopping criterion has been reached, false otherwise
*/
public boolean stopCriteria(TreeNode node) {
//double time = System.currentTimeMillis();
int [] patterns = node.getPatterns();
// If the node has fewer than two patterns (also required by the next criterion)
if (patterns.length < 2)
return true;
// If tree depth reaches user-specified limit
if(tree.depth() >= maxDepth)
return true;
// If a node becomes pure
// (all cases in a node have identical values of the dependent variable)
boolean equalDependant = true;
for (int i=0; i<patterns.length-1; i++) {
int patternIndex = patterns[i];
int nextPatternIndex = patterns[i+1]; //Be sure there is more than one pattern
double [] prev_output = dataset.getOutputs(patternIndex);
double [] next_output = dataset.getOutputs(nextPatternIndex);
equalDependant = Arrays.equals(prev_output, next_output);
if (!equalDependant) // if any difference is found, stop checking
break;
}
if (equalDependant) // if all outputs are equal, the node is pure
return true;
// TODO Other possible stop criteria
// If all cases in a node have identical values for each predictor
// If the size of a node is less than the user-specified minimum node size value
// If the split of a node results in a child whose node size is less than the
// user-specified minimum
//System.out.println("Stop Criteria: "+(System.currentTimeMillis()-time)+"ms");
// Otherwise
return false;
}
/**
* It gets the classification results
*
* @param dataset data set to classify
* @return predicted outputs in 1-of-N encoding: predicted[class][pattern] is 1
*         for the predicted class of each pattern and 0 otherwise
*
*/
public byte[][] getClassificationResults(DoubleTransposedDataSet dataset)
{
double [][] inputs = transposedMatrix(dataset.getAllInputs());
int noutputs = dataset.getNofoutputs();
int npatterns = dataset.getNofobservations();
// Result matrix with predicted values
byte [][] predicted = new byte [noutputs][npatterns];
// For each pattern determine its predicted class
TreeNode root = tree.getRoot();
for (int i=0; i<npatterns; i++) {
double [] pattern = inputs[i];
int predictedClass = (int) root.evaluate(pattern, regression);
// Initialize values
for (int j=0; j<noutputs; j++) {
predicted[j][i] = 0;
}
// Mark the predicted class
predicted[predictedClass][i] = 1;
}
// return predicted classes
return predicted;
}
/**
*
* It gets the regression results
*
* @param dataset data set to evaluate
* @return predicted output value for each pattern of the given data set
*
*/
public double[] getRegressionResults(DoubleTransposedDataSet dataset)
{
double [][] inputs = transposedMatrix(dataset.getAllInputs());
int npatterns = dataset.getNofobservations();
// Result matrix with predicted values
double [] predicted = new double [npatterns];
// For each pattern compute its predicted value
TreeNode root = tree.getRoot();
for (int i=0; i<npatterns; i++) {
double [] pattern = inputs[i];
double predictedValue = root.evaluate(pattern, regression);
// Store the predicted value
predicted[i] = predictedValue;
}
// return predicted values
return predicted;
}
/**
*
* It returns the transposed matrix of a given one
*
* @param a input matrix
* @return transposed matrix of a
*/
private double [][] transposedMatrix(double [][] a)
{
int rows = a.length;
int cols = a[rows-1].length;
double[][] b = new double [cols][rows];
for (int i=0; i< rows; i++) {
for (int j=0; j<cols; j++) {
b[j][i] = a[i][j];
}
}
return b;
}
}