/*
* RapidMiner
*
* Copyright (C) 2001-2008 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.performance;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.tools.math.Averagable;
/**
* Computes the empirical correlation coefficient 'r' between label and
* prediction. For
* <code>P=prediction, L=label, V=Variance, Cov=Covariance</code> we calculate
* r by: <br>
* <code>Cov(L,P) / sqrt(V(L)*V(P))</code>.
*
* Implementation hint: this implementation intentionally recomputes the mean
* and variance of prediction and label despite the fact that they are available
* by the Attribute objects. The reason: it can happen that there are some
* examples which have a NaN as prediction or label, but not both. In this case,
* mean and variance stored in the Attributes and computed here can differ.
*
* @author Robert Rudolph, Ingo Mierswa
* @version $Id: CorrelationCriterion.java,v 2.8 2006/03/21 15:35:50 ingomierswa
* Exp $
*/
public class CorrelationCriterion extends MeasuredPerformance {

    private static final long serialVersionUID = -8789903466296509903L;

    /** Label attribute of the example set passed to the last startCounting call. */
    private Attribute labelAttribute;

    /** Predicted label attribute of the example set passed to the last startCounting call. */
    private Attribute predictedLabelAttribute;

    /** Example weight attribute, or null if examples are counted with weight 1. */
    private Attribute weightAttribute;

    /** Sum of the weights of all counted examples. */
    private double exampleCount = 0;

    /** Weighted sum of the label values. */
    private double sumLabel;

    /** Weighted sum of the predicted values. */
    private double sumPredict;

    /** Weighted sum of the products label * prediction. */
    private double sumLabelPredict;

    /** Weighted sum of the squared label values. */
    private double sumLabelSqr;

    /** Weighted sum of the squared predicted values. */
    private double sumPredictSqr;

    /** Creates an empty criterion; all sums start at zero and no attributes are set. */
    public CorrelationCriterion() {
    }

    /**
     * Copy constructor. Copies all accumulated sums and clones the attributes
     * of the given criterion.
     *
     * @param sc the criterion to copy
     */
    public CorrelationCriterion(CorrelationCriterion sc) {
        super(sc);
        this.sumLabelPredict = sc.sumLabelPredict;
        this.sumLabelSqr = sc.sumLabelSqr;
        this.sumPredictSqr = sc.sumPredictSqr;
        this.sumLabel = sc.sumLabel;
        this.sumPredict = sc.sumPredict;
        this.exampleCount = sc.exampleCount;
        // The attributes are only set by startCounting, so a criterion which
        // has not counted anything yet still holds null references here.
        if (sc.labelAttribute != null) {
            this.labelAttribute = (Attribute) sc.labelAttribute.clone();
        }
        if (sc.predictedLabelAttribute != null) {
            this.predictedLabelAttribute = (Attribute) sc.predictedLabelAttribute.clone();
        }
        if (sc.weightAttribute != null) {
            this.weightAttribute = (Attribute) sc.weightAttribute.clone();
        }
    }

    /** Returns the sum of the weights of all examples counted so far. */
    public double getExampleCount() {
        return exampleCount;
    }

    /** Returns the maximum fitness of 1.0. */
    public double getMaxFitness() {
        return 1.0d;
    }

    /**
     * Updates all sums needed to compute the correlation coefficient.
     * Examples for which label, prediction or weight is NaN are skipped.
     *
     * @param example the example to add to the sums
     */
    public void countExample(Example example) {
        double label = example.getValue(labelAttribute);
        double plabel = example.getValue(predictedLabelAttribute);
        if (labelAttribute.isNominal()) {
            // Translate the predicted index into the index space of the label
            // mapping so that both values refer to the same nominal value.
            String predLabelString = predictedLabelAttribute.getMapping().mapIndex((int) plabel);
            plabel = labelAttribute.getMapping().getIndex(predLabelString);
        }

        double weight = 1.0d;
        if (weightAttribute != null) {
            weight = example.getValue(weightAttribute);
        }

        // The product is NaN as soon as one of its factors is NaN, so this
        // single check covers label, prediction and weight.
        double prod = label * plabel * weight;
        if (!Double.isNaN(prod)) {
            // Every sum must carry the example weight: exampleCount is the
            // sum of weights, and mixing weighted with unweighted moments
            // would yield a wrong coefficient for weights other than 1.
            sumLabelPredict += prod;
            sumLabel += label * weight;
            sumLabelSqr += label * label * weight;
            sumPredict += plabel * weight;
            sumPredictSqr += plabel * plabel * weight;
            exampleCount += weight;
        }
    }

    /** Returns a short description of this criterion. */
    public String getDescription() {
        return "Returns the correlation coefficient between the label and predicted label.";
    }

    /**
     * Computes the correlation coefficient from the accumulated sums. With
     * n = exampleCount this is
     * <code>(n * sum(LP) - sum(L) * sum(P)) / sqrt((n * sum(LL) - sum(L)^2) * (n * sum(PP) - sum(P)^2))</code>.
     * The result is NaN if the denominator is zero, e.g. if nothing was
     * counted or if label or prediction are constant.
     *
     * @return the correlation coefficient
     */
    public double getMikroAverage() {
        double covariancePart = exampleCount * sumLabelPredict - sumLabel * sumPredict;
        double labelVariancePart = exampleCount * sumLabelSqr - sumLabel * sumLabel;
        double predictVariancePart = exampleCount * sumPredictSqr - sumPredict * sumPredict;
        return covariancePart / Math.sqrt(labelVariancePart * predictVariancePart);
    }

    /** Returns NaN since no variance is computed for this criterion. */
    public double getMikroVariance() {
        return Double.NaN;
    }

    /**
     * Resets all sums and fetches label, predicted label and (optionally)
     * weight attribute from the given example set.
     *
     * @param eset the example set which will be counted
     * @param useExampleWeights whether the weight attribute of the example
     *            set should be used for weighting the examples
     * @throws OperatorException if thrown by the super implementation
     */
    public void startCounting(ExampleSet eset, boolean useExampleWeights) throws OperatorException {
        super.startCounting(eset, useExampleWeights);
        exampleCount = 0;
        sumLabelPredict = sumLabel = sumPredict = sumLabelSqr = sumPredictSqr = 0.0d;
        this.labelAttribute = eset.getAttributes().getLabel();
        this.predictedLabelAttribute = eset.getAttributes().getPredictedLabel();
        if (useExampleWeights) {
            this.weightAttribute = eset.getAttributes().getWeight();
        } else {
            // Drop a weight attribute left over from an earlier counting run,
            // otherwise a reused criterion would still weight its examples.
            this.weightAttribute = null;
        }
    }

    /**
     * Adds the sums of the given criterion to the sums of this criterion.
     *
     * @param performance the criterion to merge; must be a CorrelationCriterion
     */
    public void buildSingleAverage(Averagable performance) {
        CorrelationCriterion other = (CorrelationCriterion) performance;
        this.sumLabelPredict += other.sumLabelPredict;
        this.sumLabelSqr += other.sumLabelSqr;
        this.sumPredictSqr += other.sumPredictSqr;
        this.sumLabel += other.sumLabel;
        this.sumPredict += other.sumPredict;
        this.exampleCount += other.exampleCount;
    }

    /** Returns the fitness, which is the value delivered by getAverage(). */
    public double getFitness() {
        return getAverage();
    }

    /** Returns the name of this criterion. */
    public String getName() {
        return "correlation";
    }
}