/* * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ /* * CorrelationAttributeEval.java * Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand * */ package weka.attributeSelection; import java.util.ArrayList; import java.util.Arrays; import java.util.Enumeration; import java.util.List; import java.util.Vector; import weka.core.Capabilities; import weka.core.Capabilities.Capability; import weka.core.Instance; import weka.core.Instances; import weka.core.Option; import weka.core.OptionHandler; import weka.core.RevisionUtils; import weka.core.Utils; import weka.filters.Filter; import weka.filters.unsupervised.attribute.ReplaceMissingValues; /** <!-- globalinfo-start --> * CorrelationAttributeEval :<br/> * <br/> * Evaluates the worth of an attribute by measuring the correlation (Pearson's) between it and the class.<br/> * <br/> * Nominal attributes are considered on a value by value basis by treating each value as an indicator. An overall correlation for a nominal attribute is arrived at via a weighted average.<br/> * <p/> <!-- globalinfo-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -D * Output detailed info for nominal attributes</pre> * <!-- options-end --> * * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) * @version $Revision: 9049 $ */ public class CorrelationAttributeEval extends ASEvaluation implements AttributeEvaluator, OptionHandler { /** For serialization */ private static final long serialVersionUID = -4931946995055872438L; /** The correlation for each attribute */ protected double[] m_correlations; /** Whether to output detailed (per value) correlation for nominal attributes */ protected boolean m_detailedOutput = false; /** Holds the detailed output info */ protected StringBuffer m_detailedOutputBuff; /** * Returns a string describing this attribute evaluator * * @return a description of the evaluator suitable for displaying in the * explorer/experimenter gui */ public String globalInfo() { return "CorrelationAttributeEval :\n\nEvaluates the worth of an attribute " + "by measuring the correlation (Pearson's) between it and the class.\n\n" + "Nominal attributes are considered on a value by " + "value basis by treating each value as an indicator. An overall " + "correlation for a nominal attribute is arrived at via a weighted average.\n"; } /** * Returns the capabilities of this evaluator. * * @return the capabilities of this evaluator * @see Capabilities */ @Override public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.DATE_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.enable(Capability.NOMINAL_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); return result; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. **/ @Override public Enumeration listOptions() { // TODO Auto-generated method stub Vector<Option> newVector = new Vector<Option>(); newVector.addElement(new Option( "\tOutput detailed info for nominal attributes", "D", 0, "-D")); return newVector.elements(); } /** * Parses a given list of options. <p/> * <!-- options-start --> * Valid options are: <p/> * * <pre> -D * Output detailed info for nominal attributes</pre> * <!-- options-end --> * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ @Override public void setOptions(String[] options) throws Exception { setOutputDetailedInfo(Utils.getFlag('D', options)); } /** * Gets the current settings of WrapperSubsetEval. * * @return an array of strings suitable for passing to setOptions() */ @Override public String[] getOptions() { String[] options = new String[1]; if (getOutputDetailedInfo()) { options[0] = "-D"; } else { options[0] = ""; } return options; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String outputDetailedInfoTipText() { return "Output per value correlation for nominal attributes"; } /** * Set whether to output per-value correlation for nominal attributes * * @param d true if detailed (per-value) correlation is to be output for * nominal attributes */ public void setOutputDetailedInfo(boolean d) { m_detailedOutput = d; } /** * Get whether to output per-value correlation for nominal attributes * * @return true if detailed (per-value) correlation is to be output for * nominal attributes */ public boolean getOutputDetailedInfo() { return m_detailedOutput; } /** * Evaluates an individual attribute by measuring the correlation (Pearson's) * between it and the class. Nominal attributes are considered on a value by * value basis by treating each value as an indicator. An overall correlation * for a nominal attribute is arrived at via a weighted average. * * @param attribute the index of the attribute to be evaluated * @return the correlation * @throws Exception if the attribute could not be evaluated */ @Override public double evaluateAttribute(int attribute) throws Exception { return m_correlations[attribute]; } /** * Describe the attribute evaluator * * @return a description of the attribute evaluator as a String */ @Override public String toString() { StringBuffer buff = new StringBuffer(); if (m_correlations == null) { buff.append("Correlation attribute evaluator has not been built yet."); } else { buff.append("\tCorrelation Ranking Filter"); if (m_detailedOutput && m_detailedOutputBuff.length() > 0) { buff.append("\n\tDetailed output for nominal attributes"); buff.append(m_detailedOutputBuff); } } return buff.toString(); } /** * Initializes an information gain attribute evaluator. Replaces missing * values with means/modes; Deletes instances with missing class values. * * @param data set of instances serving as training data * @throws Exception if the evaluator has not been generated successfully */ @Override public void buildEvaluator(Instances data) throws Exception { data = new Instances(data); data.deleteWithMissingClass(); ReplaceMissingValues rmv = new ReplaceMissingValues(); rmv.setInputFormat(data); data = Filter.useFilter(data, rmv); int numClasses = data.classAttribute().numValues(); int classIndex = data.classIndex(); int numInstances = data.numInstances(); m_correlations = new double[data.numAttributes()]; /* * boolean hasNominals = false; boolean hasNumerics = false; */ List<Integer> numericIndexes = new ArrayList<Integer>(); List<Integer> nominalIndexes = new ArrayList<Integer>(); if (m_detailedOutput) { m_detailedOutputBuff = new StringBuffer(); } // TODO for instance weights (folded into computing weighted correlations) // add another dimension just before the last [2] (0 for 0/1 binary vector // and // 1 for corresponding instance weights for the 1's) double[][][] nomAtts = new double[data.numAttributes()][][]; for (int i = 0; i < data.numAttributes(); i++) { if (data.attribute(i).isNominal() && i != classIndex) { nomAtts[i] = new double[data.attribute(i).numValues()][data .numInstances()]; Arrays.fill(nomAtts[i][0], 1.0); // set zero index for this att to all // 1's nominalIndexes.add(i); } else if (data.attribute(i).isNumeric() && i != classIndex) { numericIndexes.add(i); } } // do the nominal attributes if (nominalIndexes.size() > 0) { for (int i = 0; i < data.numInstances(); i++) { Instance current = data.instance(i); for (int j = 0; j < current.numValues(); j++) { if (current.attribute(current.index(j)).isNominal() && current.index(j) != classIndex) { // Will need to check for zero in case this isn't a sparse // instance (unless we add 1 and subtract 1) nomAtts[current.index(j)][(int) current.valueSparse(j)][i] += 1; nomAtts[current.index(j)][0][i] -= 1; } } } } if (data.classAttribute().isNumeric()) { double[] classVals = data.attributeToDoubleArray(classIndex); // do the numeric attributes for (Integer i : numericIndexes) { double[] numAttVals = data.attributeToDoubleArray(i); m_correlations[i] = Utils.correlation(numAttVals, classVals, numAttVals.length); if (m_correlations[i] == 1.0) { // check for zero variance (useless numeric attribute) if (Utils.variance(numAttVals) == 0) { m_correlations[i] = 0; } } } // do the nominal attributes if (nominalIndexes.size() > 0) { // now compute the correlations for the binarized nominal attributes for (Integer i : nominalIndexes) { double sum = 0; double corr = 0; double sumCorr = 0; double sumForValue = 0; if (m_detailedOutput) { m_detailedOutputBuff.append("\n\n") .append(data.attribute(i).name()); } for (int j = 0; j < data.attribute(i).numValues(); j++) { sumForValue = Utils.sum(nomAtts[i][j]); corr = Utils .correlation(nomAtts[i][j], classVals, classVals.length); // useless attribute - all instances have the same value if (sumForValue == numInstances || sumForValue == 0) { corr = 0; } if (corr < 0.0) { corr = -corr; } sumCorr += sumForValue * corr; sum += sumForValue; if (m_detailedOutput) { m_detailedOutputBuff.append("\n\t") .append(data.attribute(i).value(j)).append(": "); m_detailedOutputBuff.append(Utils.doubleToString(corr, 6)); } } m_correlations[i] = (sum > 0) ? sumCorr / sum : 0; } } } else { // class is nominal // TODO extra dimension for storing instance weights too double[][] binarizedClasses = new double[data.classAttribute() .numValues()][data.numInstances()]; // this is equal to the number of instances for all inst weights = 1 double[] classValCounts = new double[data.classAttribute().numValues()]; for (int i = 0; i < data.numInstances(); i++) { Instance current = data.instance(i); binarizedClasses[(int) current.classValue()][i] = 1; } for (int i = 0; i < data.classAttribute().numValues(); i++) { classValCounts[i] = Utils.sum(binarizedClasses[i]); } double sumClass = Utils.sum(classValCounts); // do numeric attributes first if (numericIndexes.size() > 0) { for (Integer i : numericIndexes) { double[] numAttVals = data.attributeToDoubleArray(i); double corr = 0; double sumCorr = 0; for (int j = 0; j < data.classAttribute().numValues(); j++) { corr = Utils.correlation(numAttVals, binarizedClasses[j], numAttVals.length); if (corr < 0.0) { corr = -corr; } if (corr == 1.0) { // check for zero variance (useless numeric attribute) if (Utils.variance(numAttVals) == 0) { corr = 0; } } sumCorr += classValCounts[j] * corr; } m_correlations[i] = sumCorr / sumClass; } } if (nominalIndexes.size() > 0) { for (Integer i : nominalIndexes) { if (m_detailedOutput) { m_detailedOutputBuff.append("\n\n") .append(data.attribute(i).name()); } double sumForAtt = 0; double corrForAtt = 0; for (int j = 0; j < data.attribute(i).numValues(); j++) { double sumForValue = Utils.sum(nomAtts[i][j]); double corr = 0; double sumCorr = 0; double avgCorrForValue = 0; sumForAtt += sumForValue; for (int k = 0; k < numClasses; k++) { // corr between value j and class k corr = Utils.correlation(nomAtts[i][j], binarizedClasses[k], binarizedClasses[k].length); // useless attribute - all instances have the same value if (sumForValue == numInstances || sumForValue == 0) { corr = 0; } if (corr < 0.0) { corr = -corr; } sumCorr += classValCounts[k] * corr; } avgCorrForValue = sumCorr / sumClass; corrForAtt += sumForValue * avgCorrForValue; if (m_detailedOutput) { m_detailedOutputBuff.append("\n\t") .append(data.attribute(i).value(j)).append(": "); m_detailedOutputBuff.append(Utils.doubleToString(avgCorrForValue, 6)); } } // the weighted average corr for att i as // a whole (wighted by value frequencies) m_correlations[i] = (sumForAtt > 0) ? corrForAtt / sumForAtt : 0; } } } if (m_detailedOutputBuff != null && m_detailedOutputBuff.length() > 0) { m_detailedOutputBuff.append("\n"); } } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 9049 $"); } /** * Main method for testing this class. * * @param args the options */ public static void main(String[] args) { runEvaluator(new CorrelationAttributeEval(), args); } }