/*********************************************************************** This file is part of KEEL-software, the Data Mining tool for regression, classification, clustering, pattern mining and so on. Copyright (C) 2004-2010 F. Herrera (herrera@decsai.ugr.es) L. S�nchez (luciano@uniovi.es) J. Alcal�-Fdez (jalcala@decsai.ugr.es) S. Garc�a (sglopez@ujaen.es) A. Fern�ndez (alberto.fernandez@ujaen.es) J. Luengo (julianlm@decsai.ugr.es) This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see http://www.gnu.org/licenses/ **********************************************************************/ package keel.Algorithms.Decision_Trees.SLIQ; import java.util.*; /** Implementación en Java del algoritmo SLIQ Basada parcialmente en el código del algoritmo ID3 de Cristóbal Romero Morales (UCO) @author Francisco Charte Ojeda (práctica ICO de la UJA) @version 1.0 (28/12/09 - 10/1/10) */ /** * Clase que representa un nodo del árbol */ public class Node { /** Histograma asociado al nodo. El primer índice es el índice de clase y * el segundo es 0-izquierda ó 1-derecha, mientras que el valor indicaría * la frecuencia de esa clase en la rama indicada. La clase 0 está reservada * para conservar el total de cada rama. */ private int[][] histograma; /** indice Gini de este nodo */ private double indiceGini; /** Mejor ganancia para partir */ private double mejorGini; /** En los nodos interiores, referencias a los nodos hijo. */ private Node[] children; /** Clase asociada al nodo si es un nodo hoja*/ private int primeraClase; /** Indica si el nodo es una hoja o no */ private boolean esHoja; /** El conjunto de datos asociados al nodo. */ private Vector<ListaAtributos>[] data; /** El padre de este nodo. En la raíz parent == null. */ private Node parent; /** Valor (atributos continuos) con el mejor corte o índice del subconjunto (atributos discretos) con el mejor corte */ private double mejorValor; /** En los nodos hoja, el atributo que se utiliza para dividir el conjunto de datos. */ private int mejorAtributo; /** Coste del nodo (para la fase de poda) */ private int coste = -1; /** Número de clases existentes */ private int numClases; /** Crea un nuevo nodo. * */ public Node(int nClases) { // Inicializar los indicadores indiceGini = 1; mejorGini = 0; mejorAtributo = primeraClase = -1; esHoja = true; // Nodos hijo nulos children = new Node[2]; children[0] = children[1] = null; // Inicializar también el histograma asociado al nodo histograma = new int[nClases + 1][2]; for (int indice = 0; indice <= nClases; indice++) { histograma[indice][0] = histograma[indice][1] = 0; } // Conservar el número de clases numClases = nClases; parent = null; } /** Agregar un elemento al nodo * * @param clase Clase a la que pertenece el elemento */ public void agregaElemento(int clase) { // Si es el primer elemento agregado se toma su clase como principal if (primeraClase == -1) { primeraClase = clase; } // Si se agregan elementos de una clase distinta a la principal else if (primeraClase != clase) { esHoja = false; // el nodo no puede considerarse una hoja } // Contabilizar el nuevo dato en la clase que le corresponda histograma[clase + 1][0]++; histograma[0][0]++; // y en el total } /** Método que divide el nodo actual en dos que se agregan como hijos * */ public void divide() { children[0] = new Node(histograma.length); children[1] = new Node(histograma.length); // Actualizar los punteros al padre children[0].parent = this; children[1].parent = this; } /** Registra un elemento de la clase indicada que pasa de la hoja izquierda a la derecha * * @param clase indice de la clase del elemento */ protected void actualizaHistograma(int clase) { histograma[clase + 1][0]--; histograma[clase + 1][1]++; histograma[0][0]--; histograma[0][1]++; } /** Método que actualiza la clase principal del nodo contando la frecuencia * de las clases. Se usa después de podar un nodo */ public void actualizaClasePrincipal() { int frecuenciaClase = 0; // Se recorren las clases for (int indice = 1; indice <= numClases; indice++) { // quedándose siempre con la clase más representativa if (histograma[indice][0] + histograma[indice][1] > frecuenciaClase) { frecuenciaClase = histograma[indice][0] + histograma[indice][1]; primeraClase = indice - 1; } } } /** Método que prueba un corte y calcula la mejora que se obtendría. Para atributos discretos * * @param indAtributo indice del atributo * @param listaClases Lista de clases * @param atributo Referencia al atributo */ public void pruebaCorte(int indAtributo, ListaClases[] listaClases, Attribute atributo) { // Número máximo de valores a comprobar de manera exhaustiva, según la // descripción del algoritmo SLIQ de Mehta final int MAXSETSIZE = 10; int numValores = atributo.numValues(), // Número de valores distintos que puede tomar el atributo numClases = listaClases.length; // Número de clases a las que pueden pertenecer // Matriz de ocurrencias por valor y clase int[][] ocurrencias = new int[numClases][numValores]; int totalOcurrencias = 0; // Se inicializa todo a 0 for (int clase = 0; clase < numClases; clase++) { for (int valor = 0; valor < numValores; valor++) { ocurrencias[clase][valor] = 0; } } // Se recorre la lista de valores del nodo for (int indice = 0; indice < data[indAtributo].size(); indice++) { // Y se incrementa en la matriz de ocurrencias el elemento que corresponda int clase = listaClases[data[indAtributo].get(indice).indice].clase; int valor = (int) data[indAtributo].get(indice).valor; ocurrencias[clase][valor]++; totalOcurrencias++; } // --- Proceso para obtener el subconjunto con el mejor Gini --- double giniActual, giniSubconjunto = 1; int mejorSubconjunto = 0; // Ciclos para recorrer todas las combinaciones posibles int ciclos = (int) Math.pow(2, numValores) - 1; // Si no se supera el umbral, pueden probarse todas las combinaciones posibles if (atributo.numValues() <= MAXSETSIZE) { for (int indice = 0; indice < ciclos; indice++) { // Se obtiene el índice Gini para este subconjunto giniActual = calculaGini(indice, ocurrencias, numValores, numClases, totalOcurrencias); // Si es mejor que el mejor encontrado hasta ahora if (giniActual < giniSubconjunto) { mejorSubconjunto = indice; // Se guarda el índice del subconjunto giniSubconjunto = giniActual; // y el nuevo Gini } } } else { // Hay demasiados valores, usar algoritmo greedy mejorSubconjunto = 0; ciclos++; boolean mejorado; do { mejorado = false; // En cada ciclo se asume que no hay mejora for (int indice = 1; indice < ciclos; indice *= 2) { // se comprueba el subconjunto if ((mejorSubconjunto & indice) == 0) { giniActual = calculaGini(mejorSubconjunto + indice, ocurrencias, numValores, numClases, totalOcurrencias); if (giniActual < giniSubconjunto) { // Si hay mejora mejorSubconjunto += indice; giniSubconjunto = giniActual; mejorado = true; } } } } while (mejorado); // Mientras se mejore } // Anotar el mejor corte posible para este atributo indiceGini = giniSubconjunto; mejorAtributo = indAtributo; // Se almacena como valor el índice del mejor subconjunto encontrado mejorValor = mejorSubconjunto; } /** Método que prueba un corte y calcula la mejora que se obtendría. Para atributos continuos * * @param atributo indice del atributo * @param listaClases Lista de clases * @param valor Valor a comprobar * @param siguiente Valor siguiente */ public void pruebaCorte(int atributo, ListaClases[] listaClases, double valor, double siguiente) { // Calcular el valor intermedio entre valor y el siguiente (la lista está ordenada) double valorMedio = valor + (siguiente - valor) / 2; // Se guarda el histograma actual del nodo int[][] copiaHistograma = histograma.clone(); // Creo los dos nodos en los que se dividiría la lista de datos Node nodoI = new Node(histograma.length), nodoD = new Node(histograma.length); // Se recorre la lista de valores del atributo indicado for (int indice = 0; indice < data[atributo].size(); indice++) // y se agrega la distribución en el nodo que corresponda { if (data[atributo].get(indice).valor <= valorMedio) { nodoI.agregaElemento(listaClases[data[atributo].get(indice).indice].clase); } else { // Si el nodo cambia a la rama derecha nodoD.agregaElemento(listaClases[data[atributo].get(indice).indice].clase); // hay que actualizar también el histograma de este nodo actualizaHistograma(listaClases[data[atributo].get(indice).indice].clase); } } // Calcular el índice Gini indiceGini = calculaGini(); // Proporción de entradas en cada nodo double propIzq = nodoI.histograma[0][0] / (nodoI.histograma[0][0] + nodoD.histograma[0][0]), propDcho = nodoD.histograma[0][0] / (nodoI.histograma[0][0] + nodoD.histograma[0][0]); // Cálculo de la ganancia que se obtendría double GiniGain = indiceGini - nodoI.calculaGini() * propIzq - nodoD.calculaGini() * propDcho; // Si el GiniGain es mejor que mejorGini, guardarlo if (GiniGain > mejorGini) { mejorGini = GiniGain; mejorValor = valorMedio; // guardar los datos mejorAtributo = atributo; } // Recuperar el histograma original del nodo, para realizar correctamente // pruebas de cortes posteriores histograma = copiaHistograma; } /** Método encargado de calcular el índice Gini del nodo para atributos continuos * */ public double calculaGini() { // Tomar los totales double totalIzquierdo = histograma[0][0], totalDerecho = histograma[0][1]; double total = totalIzquierdo + totalDerecho; double probIzquierdo = 0, probDerecho = 0, prob = 0; // Si todos los datos están en una rama if (totalIzquierdo == 0 || totalDerecho == 0) { return 1; // no hay nada que calcular } // Acumular las probabilidades for (int indice = 1; indice < histograma.length; indice++) { prob = histograma[indice][0] / totalIzquierdo; probIzquierdo += prob * prob; prob = histograma[indice][1] / totalDerecho; probDerecho += prob * prob; } // Y calcular el índice a devolver return (totalIzquierdo / total) * (1 - probIzquierdo) + (totalDerecho / total) * (1 - probDerecho); } /** Método encargado de calcular el índice Gini para atributos discretos. * Está basado parcialmente en la implementación de la clase count_matrix * de la tésis de Nathan Rountree titulada 'Initialising Neural Networks * with Prior Knowledge', en la que hay un capítulo dedicado específicamente * al estudio de árboles, las técnicas de splitting y de poda. * * @param indSubconjunto indice del subconjunto a probar * @param ocurrencias Matriz de ocurrencias * @param numValores Número de valores en la matriz * @param numClases Número de clases en la matriz * @param totalOcurrencias Total de ocurrencias */ public double calculaGini(int indSubconjunto, int[][] ocurrencias, int numValores, int numClases, int totalOcurrencias) { int indice = 0, ciclos = numValores * numClases, tmpDerecha = 0, totalDerecha = 0, tmpIzquierda = 0, totalIzquierda = 0; double giniDerecha = 1, giniIzquierda = 1, peso, resultado; int[] subconjunto = new int[ciclos]; // Inicialización a cero de contadores for (int ind = 0; ind < ciclos; ind++) { subconjunto[ind] = 0; } // Contabilizar los datos que quedarían en el nodo izquierdo while (indSubconjunto > 0) { if (indSubconjunto % 2 != 0) { // Se dejan los valores impares en este subconjunto subconjunto[indice] = 1; for (int ind = 0; ind < numClases; ind++) { totalIzquierda += ocurrencias[ind][indice]; } } indSubconjunto /= 2; // Se va dividiendo por 2 indice++; } // Y en el nodo derecho totalDerecha = totalOcurrencias - totalIzquierda; // Acumular las distribuciones de los datos según las clases for (int i = 0; i < numClases; i++) { for (int j = 0; j < numValores; j++) { if (subconjunto[j] == 1) { tmpIzquierda += ocurrencias[i][j]; } else { tmpDerecha += ocurrencias[i][j]; } } peso = (double) tmpIzquierda / (double) totalIzquierda; peso *= peso; giniIzquierda -= peso; peso = (double) tmpDerecha / (double) totalDerecha; peso *= peso; giniDerecha -= peso; tmpIzquierda = tmpDerecha = 0; } // Calcular el índice Gini resultado = (totalIzquierda * giniIzquierda + totalDerecha * giniDerecha) / totalOcurrencias; return resultado; } /** Método que facilita el índice Gini asociado al índice * */ public double getIndiceGini() { return indiceGini; } /** Método para establecer los conjuntos de elementos que satisfacen la condición del nodo. * * @param newData Los conjuntos de elementos. */ public void setData(Vector<ListaAtributos>[] newData) { // Se guardan los datos data = newData; } /** Indica si el nodo es hoja * * @return true si el nodo es una hoja */ public boolean esHoja() { return this.esHoja; } /** Establece la condición de hoja de un nodo * * @param b true si el nodo es hoja */ public void setHoja(boolean b) { esHoja = b; } /** Devuelve los conjuntos de elementos que satisfacen la condición del nodo. */ public Vector<ListaAtributos>[] getData() { return data; } /** Facilita la clase más representativa del nodo * * @return indice de la clase */ public int getClase() { return primeraClase; } /** Devuelve el índice del atributo usado para descomponer el nodo. * */ public int getDecompositionAttribute() { return mejorAtributo; } /** Devuelve el valor usado para descomponer el nodo. * */ public double getDecompositionValue() { return mejorValor; } /** Método para establecer los hijos de un nodo. * * @param nodes Hijos del nodo. */ public void setChildren(Node[] nodes) { children = nodes; } /** M�todo para a�adir un hijo al nodo. * * @param node Nuevo hijo. */ public void addChildren(Node node) { children[numChildren()] = node; } /** Devuelve el número de hijos del nodo. * */ public int numChildren() { int nChildren = 0; for (int i = 0; i < children.length; i++) { if (children[i] != null) { nChildren++; } } return nChildren; } /** Devuelve los hijos del nodo. * */ public Node[] getChildren() { return children; } /** Devuelve el hijo correspondiente a un índice. * * @param index �ndice del hijo. */ public Node getChildren(int index) { return children[index]; } /** Método para establecer el nodo padre. * * @param node El padre del nodo. */ public void setParent(Node node) { parent = node; } /** Devuelve el padre del nodo. * */ public Node getParent() { return parent; } /** Devuelve el coste asociado al nodo * * @return El coste */ public int getCoste() { if (coste == -1) { calculaCoste(1); } return coste; } /** Método para calcular el coste de tener un nodo en el árbol * * @param fase Indica si se está en la fase de poda 1 o en la 2 */ public void calculaCoste(int fase) { coste = fase; // El coste es 1 para la primera fase y 2 para la segunda coste++; // Sumar el coste de la prueba del corte if (children[0] != null) // Si hay un hijo a la izquierda sumar su coste { coste += children[0].getCoste(); } if (children[1] != null) // Lo mismo si hay un hijo a la derecha { coste += children[1].getCoste(); } // Si éste es un nodo hoja o se está en la segunda fase de la poda if (esHoja() || fase == 2) // agregar también el coste del error { for (int indice = 1; indice <= numClases; indice++) { coste += histograma[indice][0] == primeraClase ? 0 : 1; } } } /** Método que calcula el coste del error al incorporar un nodo hijo * * @param hijo Hijo cuyos datos se incorporarían al padre * @return Coste del error */ public int costeError(Node hijo) { int suma = 0; // Sumar aquellos elementos cuya clase no coincida con la primeraClase // del nodo padre al que se incorporarán los datos for (int indice = 1; indice <= numClases; indice++) { if(indice != primeraClase) suma += hijo.histograma[indice][0]; //suma += hijo.histograma[indice][0] == primeraClase ? 0 : 1; } return suma; } /** Establece el coste del nodo * * @param coste Coste */ public void setCoste(int coste) { this.coste = coste; } }