package ca.pfv.spmf.algorithms.classifiers.decisiontree.id3; /* This file is copyright (c) 2008-2012 Philippe Fournier-Viger * * This file is part of the SPMF DATA MINING SOFTWARE * (http://www.philippe-fournier-viger.com/spmf). * * SPMF is free software: you can redistribute it and/or modify it under the * terms of the GNU General Public License as published by the Free Software * Foundation, either version 3 of the License, or (at your option) any later * version. * * SPMF is distributed in the hope that it will be useful, but WITHOUT ANY * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR * A PARTICULAR PURPOSE. See the GNU General Public License for more details. * You should have received a copy of the GNU General Public License along with * SPMF. If not, see <http://www.gnu.org/licenses/>. */ import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; /** * This is an implementation of the ID3 algorithm for creating a decision tree. * <br/><br/> * ID3 is a very popular algorithms described in many artificial intelligence * and data mining textbooks. * * @author Philippe Fournier-Viger */ public class AlgoID3 { // the list of attributes private String[] allAttributes; // the position of the target attribute in the list of attributes private int indexTargetAttribute = -1; // the set of values for the target attribute private Set<String> targetAttributeValues = new HashSet<String>(); // for statistics private long startTime; // start time of the latest execution private long endTime; // end time of the latest execution /** * Create a decision tree from a set of training instances. * @param input path to an input file containing training instances * @param targetAttribute the target attribute (that will be used for classification) * @param separator the separator in the input file (e.g. space). * @return a decision tree * @throws IOException exception if error reading the file */ public DecisionTree runAlgorithm(String input, String targetAttribute, String separator) throws IOException { // record the start time startTime = System.currentTimeMillis(); // create an empty decision tree DecisionTree tree = new DecisionTree(); // (1) read input file BufferedReader reader = new BufferedReader(new FileReader(input)); String line = reader.readLine(); // Read the first line and note the name of the attributes. // At the same time identify the position of the target attribute and // other attributes. allAttributes = line.split(separator); // make an array to store the attributes except the target attribute int[] remainingAttributes = new int[allAttributes.length - 1]; int pos = 0; // for each attribute for (int i = 0; i < allAttributes.length; i++) { // if it is the target attribute if (allAttributes[i].equals(targetAttribute)) { // save the position of the target attribute. It will be useful // later. indexTargetAttribute = i; } else { // otherwise add the attribute to the array of attributes remainingAttributes[pos++] = i; } } // Read instances into memory (line by line until end of file) List<String[]> instances = new ArrayList<String[]>(); while (((line = reader.readLine()) != null)) { // if the line is a comment, is empty or is a // kind of metadata if (line.isEmpty() == true || line.charAt(0) == '#' || line.charAt(0) == '%' || line.charAt(0) == '@') { continue; } // split the line String[] lineSplit = line.split(separator); // process the instance instances.add(lineSplit); // remember the value for the target attribute targetAttributeValues.add(lineSplit[indexTargetAttribute]); } reader.close(); // close input file // (2) Start the recusive process // create the tree tree.root = id3(remainingAttributes, instances); tree.allAttributes = allAttributes; endTime = System.currentTimeMillis(); // record end time return tree; // return the tree } /** * Method to create a subtree according to a set of attributes and training * instances. * @param remainingAttributes remaining attributes to create the tree * @param instances a list of training instances * @return node of the subtree created */ private Node id3(int[] remainingAttributes, List<String[]> instances) { // if only one remaining attribute, // return a class node with the most common value in the instances if (remainingAttributes.length == 0) { // Count the frequency of class Map<String, Integer> targetValuesFrequency = calculateFrequencyOfAttributeValues( instances, indexTargetAttribute); // Loop over the values to find the class with the highest frequency int highestCount = 0; String highestName = ""; for (Entry<String, Integer> entry : targetValuesFrequency .entrySet()) { // if the frequency is higher if (entry.getValue() > highestCount) { highestCount = entry.getValue(); highestName = entry.getKey(); } } // return a class node with the value having the highest frequency ClassNode classNode = new ClassNode(); classNode.className = highestName; return classNode; } // Calculate the frequency of each target attribute value and // at the same time check if there is a single class. Map<String, Integer> targetValuesFrequency = calculateFrequencyOfAttributeValues( instances, indexTargetAttribute); // if all instances are from the same class if (targetValuesFrequency.entrySet().size() == 1) { ClassNode classNode = new ClassNode(); classNode.className = (String) targetValuesFrequency.keySet() .toArray()[0]; return classNode; } // Calculate global entropy double globalEntropy = 0d; // for each value for (String value : targetAttributeValues) { // calculate frequency Integer frequencyInt = targetValuesFrequency.get(value); // if the frequency is not zero if(frequencyInt != null) { // calculate the frequency has a double double frequencyDouble = frequencyInt / (double) instances.size(); // update the global entropy globalEntropy -= frequencyDouble * Math.log(frequencyDouble) / Math.log(2); } } // System.out.println("Global entropy = " + globalEntropy); // Select the attribute from remaining attributes such that if we split // the dataset on this // attribute, we will get the higher information gain int attributeWithHighestGain = 0; double highestGain = -99999; for (int attribute : remainingAttributes) { double gain = calculateGain(attribute, instances, globalEntropy); // System.out.println("Process " + allAttributes[attribute] + // " gain = " + gain); if (gain >= highestGain) { highestGain = gain; attributeWithHighestGain = attribute; } } // if the highest gain is 0.... if (highestGain == 0) { ClassNode classNode = new ClassNode(); // take the most frequent classes int topFrequency = 0; String className = null; for(Entry<String, Integer> entry: targetValuesFrequency.entrySet()) { if(entry.getValue() > topFrequency) { topFrequency = entry.getValue(); className = entry.getKey(); } } classNode.className = className; return classNode; } // Create a decision node for the attribute // System.out.println("Attribute with highest gain = " + // allAttributes[attributeWithHighestGain] + " " + highestGain); DecisionNode decisionNode = new DecisionNode(); decisionNode.attribute = attributeWithHighestGain; // calculate the list of remaining attribute after we remove the // attribute int[] newRemainingAttribute = new int[remainingAttributes.length - 1]; int pos = 0; for (int i = 0; i < remainingAttributes.length; i++) { if (remainingAttributes[i] != attributeWithHighestGain) { newRemainingAttribute[pos++] = remainingAttributes[i]; } } // Split the dataset into partitions according to the selected attribute Map<String, List<String[]>> partitions = new HashMap<String, List<String[]>>(); for (String[] instance : instances) { String value = instance[attributeWithHighestGain]; List<String[]> listInstances = partitions.get(value); if (listInstances == null) { listInstances = new ArrayList<String[]>(); partitions.put(value, listInstances); } listInstances.add(instance); } // Create the values for the subnodes decisionNode.nodes = new Node[partitions.size()]; decisionNode.attributeValues = new String[partitions.size()]; // For each partition, make a recursive call to create // the corresponding branches in the tree. int index = 0; for (Entry<String, List<String[]>> partition : partitions.entrySet()) { decisionNode.attributeValues[index] = partition.getKey(); decisionNode.nodes[index] = id3(newRemainingAttribute, partition.getValue()); // recursive call index++; } // return the root node of the subtree created return decisionNode; } /** * Calculate the information gain of an attribute for a set of instance * @param attributePos the position of the attribute * @param instances a list of instances * @param globalEntropy the global entropy * @return the gain */ private double calculateGain(int attributePos, List<String[]> instances, double globalEntropy) { // Count the frequency of each value for the attribute Map<String, Integer> valuesFrequency = calculateFrequencyOfAttributeValues( instances, attributePos); // Calculate the gain double sum = 0; // for each value for (Entry<String, Integer> entry : valuesFrequency.entrySet()) { // make the sum sum += entry.getValue() / ((double) instances.size()) * calculateEntropyIfValue(instances, attributePos, entry.getKey()); } // subtract the sum from the global entropy return globalEntropy - sum; } /** * Calculate the entropy for the target attribute, if a given attribute has * a given value. * * @param instances * : list of instances * @param attributeIF * : the given attribute * @param valueIF * : the given value * @return entropy */ private double calculateEntropyIfValue(List<String[]> instances, int attributeIF, String valueIF) { // variable to count the number of instance having the value for that // attribute int instancesCount = 0; // variable to count the frequency of each value Map<String, Integer> valuesFrequency = new HashMap<String, Integer>(); // for each instance for (String[] instance : instances) { // if that instance has the value for the attribute if (instance[attributeIF].equals(valueIF)) { String targetValue = instance[indexTargetAttribute]; // increase the frequency if (valuesFrequency.get(targetValue) == null) { valuesFrequency.put(targetValue, 1); } else { valuesFrequency.put(targetValue, valuesFrequency.get(targetValue) + 1); } // increase the number of instance having the value for that // attribute instancesCount++; } } // calculate entropy double entropy = 0; // for each value of the target attribute for (String value : targetAttributeValues) { // get the frequency Integer count = valuesFrequency.get(value); // if the frequency is not null if (count != null) { // update entropy according to the formula double frequency = count / (double) instancesCount; entropy -= frequency * Math.log(frequency) / Math.log(2); } } return entropy; } /** * This method calculates the frequency of each value for an attribute in a * given set of instances * * @param instances * A set of instances * @param indexAttribute * The attribute. * @return A map where the keys are attributes and values are the number of * times that the value appeared in the set of instances. */ private Map<String, Integer> calculateFrequencyOfAttributeValues( List<String[]> instances, int indexAttribute) { // A map to calculate the frequency of each value: // Key: a string indicating a value // Value: the frequency Map<String, Integer> targetValuesFrequency = new HashMap<String, Integer>(); // for each instance of the training set for (String[] instance : instances) { // get the value of the attribute for that instance String targetValue = instance[indexAttribute]; // increase the frequency by 1 if (targetValuesFrequency.get(targetValue) == null) { targetValuesFrequency.put(targetValue, 1); } else { targetValuesFrequency.put(targetValue, targetValuesFrequency.get(targetValue) + 1); } } // return the map return targetValuesFrequency; } /** * Print statistics about the execution of this algorithm */ public void printStatistics() { System.out.println("Time to construct decision tree = " + (endTime - startTime) + " ms"); System.out.println("Target attribute = " + allAttributes[indexTargetAttribute]); System.out.print("Other attributes = "); for (String attribute : allAttributes) { if (!attribute.equals(allAttributes[indexTargetAttribute])) { System.out.print(attribute + " "); } } System.out.println(); } }