/***********************************************************************
This file is part of KEEL-software, the Data Mining tool for regression,
classification, clustering, pattern mining and so on.
Copyright (C) 2004-2010
F. Herrera (herrera@decsai.ugr.es)
L. Sánchez (luciano@uniovi.es)
J. Alcalá-Fdez (jalcala@decsai.ugr.es)
S. García (sglopez@ujaen.es)
A. Fernández (alberto.fernandez@ujaen.es)
J. Luengo (julianlm@decsai.ugr.es)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see http://www.gnu.org/licenses/
**********************************************************************/
/**
* <p>
* @author Written by Cristobal Romero (Universidad de Córdoba) 10/10/2007
* @version 0.1
* @since JDK 1.5
*</p>
*/
package keel.Algorithms.Semi_Supervised_Learning.Basic.C45;
import java.util.*;
/**
* Class that implements the computation of the cut point (split) on one attribute for the C4.5 algorithm.
*/
public class Cut {
/** Distribution of the itemsets per subset and per class for this cut. */
protected Classification classification;
/**
 * Number of subsets produced by the cut: 1 when no cut is made, 0 while no
 * valid cut has been found (see {@code checkModel}), 2 for a continuous
 * attribute and the number of branches for a discrete one.
 */
protected int numSubsets;
/** Number of branches; for discrete attributes it is set in {@code classify} to the number of attribute values. */
private int nBranches;
/** Index of the attribute to cut on. */
private int attributeIndex;
/** Minimum number of itemsets per leaf. */
private int minItemsets;
/** Cut point: on a continuous attribute, values less than or equal to it fall in subset 0. */
private double cutPoint;
/** Information gain of the cut. */
private double infoGain;
/** Gain ratio of the cut. */
private double gainRatio;
/** The sum of the weights of the itemsets. */
private double sumOfWeights;
/** Number of candidate cut points counted while searching for the cut. */
private int nCuts;
/**
 * Creates a cut model for one attribute. The cut itself is computed later
 * by {@code classify}.
 *
 * @param index The index of the attribute to cut on.
 * @param nObj Minimum number of itemsets per leaf.
 * @param weights The sum of the weights of all the itemsets.
 */
public Cut(int index, int nObj, double weights) {
    this.attributeIndex = index;
    this.minItemsets = nObj;
    this.sumOfWeights = weights;
}
/**
 * Creates a model for the case in which no cut is necessary: a single
 * subset holding a copy of the given distribution.
 *
 * @param dist Distribution of values per class.
 */
public Cut(Classification dist) {
    this.numSubsets = 1;
    this.classification = new Classification(dist);
}
/**
 * Builds the cut for the given training data. When the model was created
 * without a cut (one subset) only the class distribution is rebuilt;
 * otherwise the cut is searched on the configured attribute.
 *
 * @param trainItemsets The dataset to classify.
 *
 * @throws Exception If the classification cannot be made.
 */
public void classify(Dataset trainItemsets) throws Exception {
    // No cut: the distribution of the whole dataset is all that is needed.
    if (numSubsets == 1) {
        classification = new Classification(trainItemsets);
        return;
    }

    // Reset the state describing the cut before searching for a new one.
    numSubsets = 0;
    cutPoint = Double.MAX_VALUE;
    infoGain = 0;
    gainRatio = 0;

    // Numeric attributes: sort by the attribute and look for a threshold.
    if (!trainItemsets.getAttribute(attributeIndex).isDiscret()) {
        nCuts = 0;
        trainItemsets.sort(attributeIndex);
        cutContinuous(trainItemsets);
        return;
    }

    // Enumerated attributes: one branch per attribute value.
    if (nBranches == 2) {
        nCuts = 0;
    } else {
        nBranches = trainItemsets.getAttribute(attributeIndex).numValues();
        nCuts = nBranches;
    }
    cutDiscret(trainItemsets);
}
/**
 * Computes the probability of a class for an itemset.
 *
 * @param classIndex The index of the class.
 * @param itemset The itemset.
 * @param subset The index of the subset the itemset belongs to, or a
 *        negative value when the subset is not known.
 *
 * @return The probability computed.
 */
public final double classProbability(int classIndex, Itemset itemset,
        int subset) {
    // Unknown subset: combine the per-subset probabilities using the
    // itemset's weights, or fall back to the overall class probability.
    if (subset <= -1) {
        double[] subsetWeights = weights(itemset);
        if (subsetWeights == null) {
            return classification.probability(classIndex);
        }
        double combined = 0;
        for (int k = 0; k < subsetWeights.length; k++) {
            combined += subsetWeights[k] *
                    classification.probability(classIndex, k);
        }
        return combined;
    }

    // Known subset: use its own distribution whenever there is no cut or
    // the subset is not empty.
    if (numSubsets == 1 || classification.perValue(subset) > 0) {
        return classification.probability(classIndex, subset);
    }

    // Empty subset: all the probability goes to the majority class.
    return (classification.maxClass() == classIndex) ? 1 : 0;
}
/**
 * Searches the best binary cut on a continuous attribute and, if a useful
 * one exists, stores its cut point, distribution, information gain and
 * gain ratio.
 *
 * NOTE(review): the scan stops at the first itemset with a missing value,
 * so it presumably relies on the dataset being sorted by the attribute
 * with missing values last (classify sorts before calling) - confirm
 * against Dataset.sort.
 *
 * @param trainItemsets The dataset used to compute the cut.
 */
private void cutContinuous(Dataset trainItemsets) {
    classification = new Classification(2, trainItemsets.numClasses());

    // Put every itemset with a known value into subset 1.
    int knownCount = 0;
    Enumeration items = trainItemsets.enumerateItemsets();
    while (items.hasMoreElements()) {
        Itemset current = (Itemset) items.nextElement();
        if (current.isMissing(attributeIndex)) {
            break;
        }
        classification.add(1, current);
        knownCount++;
    }

    // Minimum weight required in each subset: 10% of the average weight
    // per class, raised to minItemsets if smaller, otherwise capped at 25.
    double minCut = 0.1 * (classification.getTotal()) /
            ((double) trainItemsets.numClasses());
    if (minCut <= minItemsets) {
        minCut = minItemsets;
    } else if (minCut > 25) {
        minCut = 25;
    }

    // Not enough itemsets with known values to fill two subsets.
    if ((double) knownCount < 2 * minCut) {
        return;
    }

    // Evaluate every position where the attribute value changes.
    double baseEntropy = oldEntropy(classification);
    int bestIndex = -1;
    int rangeStart = 0;
    for (int pos = 1; pos < knownCount; pos++) {
        double previous =
                trainItemsets.itemset(pos - 1).getValue(attributeIndex);
        double present =
                trainItemsets.itemset(pos).getValue(attributeIndex);
        if (!(previous + 1e-5 < present)) {
            continue;
        }

        // Move the itemsets up to this candidate from subset 1 to subset 0.
        classification.shiftRange(1, 0, trainItemsets, rangeStart, pos);

        // Only cuts leaving enough weight on both sides are candidates.
        if (classification.perValue(0) >= minCut &&
                classification.perValue(1) >= minCut) {
            double candidateGain = infoGainCutCrit(
                    classification, sumOfWeights, baseEntropy);
            if (candidateGain > infoGain) {
                infoGain = candidateGain;
                bestIndex = pos - 1;
            }
            nCuts++;
        }
        rangeStart = pos;
    }

    // No candidate cut was acceptable.
    if (nCuts == 0) {
        return;
    }

    // Penalise the gain by log2(number of candidates) / sum of weights.
    infoGain -= (Math.log(nCuts) / Math.log(2)) / sumOfWeights;
    if (infoGain <= 0) {
        return;
    }

    // Record the best cut: the point halfway between the two neighbours.
    numSubsets = 2;
    cutPoint = (trainItemsets.itemset(bestIndex + 1).getValue(attributeIndex) +
            trainItemsets.itemset(bestIndex).getValue(attributeIndex)) / 2;

    // Rebuild the distribution that corresponds to the best cut.
    classification = new Classification(2, trainItemsets.numClasses());
    classification.addRange(0, trainItemsets, 0, bestIndex + 1);
    classification.addRange(1, trainItemsets, bestIndex + 1, knownCount);

    gainRatio = gainRatioCutCrit(classification, sumOfWeights, infoGain);
}
/**
 * Builds the cut on a discrete attribute: one subset per branch, filled
 * with the itemsets whose value for the attribute is known.
 *
 * @param trainItemsets The dataset used to compute the cut.
 */
private void cutDiscret(Dataset trainItemsets) {
    classification = new Classification(nBranches, trainItemsets.numClasses());

    // Itemsets with a missing value do not take part in the distribution.
    for (Enumeration items = trainItemsets.enumerateItemsets();
            items.hasMoreElements();) {
        Itemset current = (Itemset) items.nextElement();
        if (current.isMissing(attributeIndex)) {
            continue;
        }
        classification.add((int) current.getValue(attributeIndex), current);
    }

    // The cut is kept only if the distribution passes the minimum-size check.
    if (!classification.check(minItemsets)) {
        return;
    }
    numSubsets = nBranches;
    double entropyBefore = oldEntropy(classification);
    infoGain = infoGainCutCrit(classification, sumOfWeights, entropyBefore);
    gainRatio = gainRatioCutCrit(classification, sumOfWeights, infoGain);
}
/**
 * Moves the cut point of a continuous cut to the largest attribute value
 * in the given dataset that does not exceed the current cut point. Does
 * nothing for discrete attributes or when there is no real cut.
 *
 * @param allItemsets The dataset used for the cut.
 */
public final void setCutPoint(Dataset allItemsets) {
    if (!allItemsets.getAttribute(attributeIndex).isContinuous() ||
            numSubsets <= 1) {
        return;
    }

    double largestBelow = -Double.MAX_VALUE;
    Enumeration items = allItemsets.enumerateItemsets();
    while (items.hasMoreElements()) {
        Itemset current = (Itemset) items.nextElement();
        if (current.isMissing(attributeIndex)) {
            continue;
        }
        double candidate = current.getValue(attributeIndex);
        if (candidate > largestBelow && candidate <= cutPoint) {
            largestBelow = candidate;
        }
    }
    cutPoint = largestBelow;
}
/**
 * Splits a dataset into one dataset per subset of the cut. An itemset
 * that cannot be assigned to a single subset is added to every subset
 * with a positive weight, with its weight scaled accordingly.
 *
 * @param data The dataset to cut.
 *
 * @return All the datasets created.
 *
 * @throws Exception If the dataset cannot be cut.
 */
public final Dataset[] cutDataset(Dataset data) throws Exception {
    Dataset[] parts = new Dataset[numSubsets];
    for (int s = 0; s < numSubsets; s++) {
        parts[s] = new Dataset(data, data.numItemsets());
    }

    for (int row = 0; row < data.numItemsets(); row++) {
        Itemset current = data.itemset(row);
        double[] share = weights(current);
        int target = whichSubset(current);

        if (target > -1) {
            parts[target].addItemset(current);
            continue;
        }

        // No single subset: distribute the itemset proportionally.
        for (int s = 0; s < numSubsets; s++) {
            if (share[s] > 0) {
                double scaledWeight = share[s] * current.getWeight();
                parts[s].addItemset(current);
                parts[s].lastItemset().setWeight(scaledWeight);
            }
        }
    }

    // Release the capacity that was reserved but not used.
    for (Dataset part : parts) {
        ((Vector) part.itemsets).trimToSize();
    }
    return parts;
}
/**
 * Rebuilds the distribution of the model from a new dataset, keeping the
 * cut itself unchanged.
 *
 * @param data The new dataset used.
 *
 * @throws Exception If the classification cannot be reset.
 */
public void resetClassification(Dataset data) throws Exception {
    if (numSubsets == 1) {
        classification = new Classification(data, this);
        return;
    }

    // First distribute the itemsets that fall in exactly one subset ...
    Dataset assignable = new Dataset(data, data.numItemsets());
    for (int row = 0; row < data.numItemsets(); row++) {
        Itemset current = data.itemset(row);
        if (whichSubset(current) > -1) {
            assignable.addItemset(current);
        }
    }
    Classification rebuilt = new Classification(assignable, this);

    // ... then add the ones whose value for the cut attribute is unknown.
    rebuilt.addWithUnknownValue(data, attributeIndex);
    classification = rebuilt;
}
/**
 * Returns the weights with which an itemset is spread over the subsets
 * when it cannot be assigned to a single one.
 *
 * @param itemset The itemset.
 *
 * @return For an itemset with a missing value on the cut attribute, the
 *         fraction of the total held by each subset; {@code null} when
 *         there is only one subset or the value is not missing.
 */
public final double[] weights(Itemset itemset) {
    if (numSubsets == 1 || !itemset.isMissing(attributeIndex)) {
        return null;
    }

    double[] shares = new double[numSubsets];
    for (int s = 0; s < numSubsets; s++) {
        shares[s] = classification.perValue(s) / classification.getTotal();
    }
    return shares;
}
/**
 * Returns the index of the subset an itemset is assigned to.
 *
 * @param itemset The itemset.
 *
 * @return 0 when there is only one subset; -1 when the value of the cut
 *         attribute is missing; the attribute value itself for a discrete
 *         attribute; otherwise 0 if the value is less than or equal to
 *         the cut point and 1 if it is greater.
 */
public final int whichSubset(Itemset itemset) {
    if (numSubsets == 1) {
        return 0;
    }
    if (itemset.isMissing(attributeIndex)) {
        return -1;
    }
    if (itemset.getAttribute(attributeIndex).isDiscret()) {
        return (int) itemset.getValue(attributeIndex);
    }
    return (itemset.getValue(attributeIndex) <= cutPoint) ? 0 : 1;
}
/**
 * Checks whether the generated model is valid, i.e. it has at least one
 * subset.
 *
 * @return True if the model is valid. False otherwise.
 */
public final boolean checkModel() {
    return numSubsets > 0;
}
/**
 * Returns the classification created by the model.
 *
 * @return The current distribution of the model.
 */
public final Classification classification() {
    return this.classification;
}
/**
 * Returns the number of subsets created by the cut.
 *
 * @return The number of subsets.
 */
public final int numSubsets() {
    return this.numSubsets;
}
/**
 * Computes the gain ratio of a cut: the information gain divided by the
 * entropy of the cut per itemset.
 *
 * @param values The classification used to compute the gain ratio.
 * @param totalnoInst Number of itemsets.
 * @param numerator The information gain.
 *
 * @return The gain ratio for the classification, or 0 when the entropy of
 *         the cut is 0 (trivial cut).
 */
public final double gainRatioCutCrit(Classification values,
        double totalnoInst, double numerator) {
    double splitInfo = cutEntropy(values, totalnoInst);

    // A trivial cut carries no split information.
    if (splitInfo == 0) {
        return 0;
    }
    return numerator / (splitInfo / totalnoInst);
}
/**
 * Computes the information gain of a cut, scaled by the rate of itemsets
 * with a known value and divided by the total of the classification.
 *
 * @param values The classification used to compute the information gain.
 * @param totalNoInst Number of itemsets.
 * @param oldEnt The value for the entropy before cutting.
 *
 * @return The information gain, or 0 when the cut yields no gain.
 */
public final double infoGainCutCrit(Classification values,
        double totalNoInst, double oldEnt) {
    // Fraction of the itemsets that is not covered by the classification.
    double unknownRate = (totalNoInst - values.getTotal()) / totalNoInst;
    double gain = (1 - unknownRate) * (oldEnt - newEntropy(values));

    // Cuts with no gain are useless.
    if (gain == 0) {
        return 0;
    }
    return gain / values.getTotal();
}
/**
 * Computes the entropy of the cut itself, treating the itemsets not
 * covered by the classification as one extra group.
 *
 * @param values The classification used to compute the entropy.
 * @param totalnoInst Number of itemsets.
 *
 * @return The entropy of the cut; 0 when the classification is empty.
 */
private final double cutEntropy(Classification values, double totalnoInst) {
    if (!(values.getTotal() > 0)) {
        return 0;
    }

    double uncovered = totalnoInst - values.getTotal();
    double entropy = 0;
    for (int v = 0; v < values.numValues(); v++) {
        entropy -= logFunc(values.perValue(v));
    }
    entropy -= logFunc(uncovered);
    entropy += logFunc(totalnoInst);
    return entropy;
}
/**
 * Computes the entropy of a classification before cutting, from its
 * per-class totals.
 *
 * @param values The classification used to compute the entropy before cutting.
 *
 * @return The entropy for the classification before cutting.
 */
public final double oldEntropy(Classification values) {
    double perClassSum = 0;
    for (int c = 0; c < values.numClasses(); c++) {
        perClassSum += logFunc(values.perClass(c));
    }
    return logFunc(values.getTotal()) - perClassSum;
}
/**
 * Computes the entropy of a classification after cutting, accumulated
 * over every subset of the cut.
 *
 * @param values The classification used to compute the entropy after cutting.
 *
 * @return The entropy for the classification after cutting.
 */
public final double newEntropy(Classification values) {
    double accumulated = 0;
    for (int v = 0; v < values.numValues(); v++) {
        for (int c = 0; c < values.numClasses(); c++) {
            accumulated += logFunc(values.perClassPerValue(v, c));
        }
        accumulated -= logFunc(values.perValue(v));
    }
    return -accumulated;
}
/**
 * Returns {@code num * log2(num)}, the term used by the entropy
 * computations.
 *
 * @param num The number to compute the term for.
 *
 * @return {@code num * log2(num)}, or 0 when {@code num} is below 1e-6.
 */
protected final double logFunc(double num) {
    // Values below the hard-coded threshold are treated as zero.
    return (num < 1e-6) ? 0 : num * Math.log(num) / Math.log(2);
}
/**
 * Returns the information gain of the generated cut.
 *
 * @return The information gain.
 */
public final double getInfoGain() {
    return this.infoGain;
}
/**
 * Returns the gain ratio of the generated cut.
 *
 * @return The gain ratio.
 */
public final double getGainRatio() {
    return this.gainRatio;
}
/**
 * Returns the left side of the condition of the cut.
 *
 * @param data The dataset.
 *
 * @return The name of the attribute used in the cut, or an empty string
 *         when there is only one subset.
 */
public final String leftSide(Dataset data) {
    return (numSubsets == 1)
            ? ""
            : data.getAttribute(attributeIndex).name();
}
/**
 * Returns the condition satisfied by the itemsets of one subset, to be
 * printed after the attribute name.
 *
 * @param index The index of the subset (the attribute value for a
 *        discrete attribute).
 * @param data The dataset.
 *
 * @return An equality with the attribute value for a discrete attribute,
 *         a comparison with the cut point (6 decimals) for a continuous
 *         one, or an empty string when there is only one subset.
 */
public final String rightSide(int index, Dataset data) {
    if (numSubsets == 1) {
        return "";
    }
    if (data.getAttribute(attributeIndex).isDiscret()) {
        return " = " + data.getAttribute(attributeIndex).value(index);
    }
    // Subset 0 holds the values up to the cut point, the other one the rest.
    String comparison = (index == 0) ? " <= " : " > ";
    return comparison + doubleToString(cutPoint, 6);
}
/**
 * Returns the label of a subset: the value of the class attribute that
 * is the majority class of that subset.
 *
 * @param index The index of the subset.
 * @param data The dataset.
 *
 * @return The label created.
 */
public final String label(int index, Dataset data) {
    int majorityClass = classification.maxClass(index);
    return String.valueOf(data.getClassAttribute().value(majorityClass));
}
/**
 * Returns the index of the attribute to cut on.
 *
 * @return The attribute index.
 */
public final int attributeIndex() {
    return this.attributeIndex;
}
/**
 * Rounds a double to a fixed number of decimals and converts it into a
 * String. Trailing zeros are kept, so the result always has exactly
 * {@code afterDecimalPoint} decimals. Values whose scaled magnitude does
 * not fit in a long (including NaN and infinities) are returned in the
 * default Java representation instead.
 *
 * @param value The value to print.
 * @param afterDecimalPoint Number of decimals positions (expected to be
 *        non-negative).
 *
 * @return The value with the given number of decimals.
 */
public static String doubleToString(double value, int afterDecimalPoint) {
    double scaled = value * Math.pow(10.0, afterDecimalPoint);

    // Too large to round through a long (also true for NaN).
    if (!(Math.abs(scaled) < Long.MAX_VALUE)) {
        return String.valueOf(value);
    }

    // Round half away from zero.
    long rounded = (scaled > 0)
            ? (long) (scaled + 0.5)
            : -(long) (Math.abs(scaled) + 0.5);
    StringBuilder digits = new StringBuilder(String.valueOf(rounded));
    if (afterDecimalPoint == 0) {
        return digits.toString();
    }

    // Zeros are inserted after the sign, if there is one.
    int firstDigit = (rounded < 0) ? 1 : 0;

    // Left-pad with zeros until the decimal point fits inside the digits.
    int dotPosition = digits.length() - afterDecimalPoint;
    while (dotPosition < firstDigit) {
        digits.insert(firstDigit, '0');
        dotPosition++;
    }
    digits.insert(dotPosition, '.');

    // Make sure there is a digit in front of the decimal point.
    if (digits.charAt(firstDigit) == '.') {
        digits.insert(firstDigit, '0');
    }

    // Drop a decimal point left at the very end.
    int lastIndex = digits.length() - 1;
    if (digits.charAt(lastIndex) == '.') {
        digits.setCharAt(lastIndex, ' ');
    }
    return digits.toString().trim();
}
/**
 * Rounds a double, converts it into a String and right-aligns it on the
 * decimal point inside a field of the given width. When the text is in
 * scientific notation, the number of decimals does not fit in the width,
 * or there is not enough room to align, the unpadded text is returned.
 *
 * @param value The value to print.
 * @param width The width that must have the string generated.
 * @param afterDecimalPoint Number of decimals positions.
 *
 * @return The value with the given number of decimals.
 */
public static String doubleToString(double value, int width,
        int afterDecimalPoint) {
    String plain = doubleToString(value, afterDecimalPoint);

    // Protects sci notation
    if (afterDecimalPoint >= width || plain.indexOf('E') != -1) {
        return plain;
    }

    char[] padded = new char[width];
    Arrays.fill(padded, ' ');

    // Locate the decimal point; without one the whole text is the
    // integer part.
    int dotPosition = plain.length();
    if (afterDecimalPoint > 0) {
        int found = plain.indexOf('.');
        if (found != -1) {
            dotPosition = found;
            padded[width - afterDecimalPoint - 1] = '.';
        }
    }

    // Column where the first character of the integer part goes.
    int offset = width - afterDecimalPoint - dotPosition;
    if (afterDecimalPoint > 0) {
        offset--;
    }

    // Not enough room to decimal align within the supplied width
    if (offset < 0) {
        return plain;
    }

    // Characters before the decimal point, then the ones after it.
    for (int k = 0; k < dotPosition; k++) {
        padded[offset + k] = plain.charAt(k);
    }
    for (int k = dotPosition + 1; k < plain.length(); k++) {
        padded[offset + k] = plain.charAt(k);
    }
    return new String(padded);
}
}