/* RocData.java created 2007-12-06
*
*/
package org.signalml.domain.roc;
import static org.signalml.app.util.i18n.SvarogI18n._;
import java.beans.IntrospectionException;
import java.io.IOException;
import java.io.Writer;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import org.signalml.app.model.components.LabelledPropertyDescriptor;
import org.signalml.app.model.components.PropertyProvider;
import org.signalml.app.model.components.WriterExportableTable;
import org.signalml.method.iterator.IterableParameter;
import org.signalml.method.iterator.MethodIteratorData;
import org.signalml.method.iterator.ParameterIterationSettings;
/**
 * Container for the points of a ROC curve together with the iterated
 * parameters that produced them.
 *
 * <p>Besides storing the points, the class lazily derives a few summary
 * statistics (area under the curve, best accuracy, and the points closest to
 * the line {@code trueRate = 1 - falseRate}, referred to as "OSS" in this
 * class). The statistics are recomputed on the first statistics getter call
 * after a point has been added.
 *
 * <p>Points are used in the order in which they were added; the class does
 * not sort them. The area computation presumably expects them to be ordered
 * by increasing false rate (to be confirmed against the callers).
 *
 * @author Michal Dobaczewski © 2007-2008 CC Otwarte Systemy Komputerowe Sp. z o.o.
 */
public class RocData implements WriterExportableTable, PropertyProvider {

    /** Parameters that were iterated; one value per parameter is stored in every point. */
    private final IterableParameter[] parameters;

    /** ROC points in insertion order. */
    private final List<RocDataPoint> rocDataPoints;

    // statistics (valid only while dirtyStatistics is false)

    /** Set whenever the point list changes; cleared by {@link #calculateStatistics()}. */
    private boolean dirtyStatistics = true;

    private double areaUnderCurve;

    private int maxAccuracyIteration;
    private double maxAccuracy;

    private int pointBelowOSSIteration;
    private double pointBelowOSSDistance;

    private int pointAboveOSSIteration;
    private double pointAboveOSSDistance;

    private double ossIntersectionTP;

    /**
     * Creates an empty data set for the given parameters.
     *
     * @param parameters the iterated parameters; the array is stored as is (not copied)
     * @throws NullPointerException if {@code parameters} is {@code null}
     * @throws IllegalArgumentException if {@code parameters} is empty
     */
    public RocData(IterableParameter[] parameters) {
        if (parameters == null) {
            throw new NullPointerException("No parameters");
        }
        if (parameters.length == 0) {
            throw new IllegalArgumentException("At least one parameter needed");
        }
        this.parameters = parameters;
        this.rocDataPoints = new ArrayList<RocDataPoint>();
    }

    /**
     * Creates a data set for the given parameters, pre-filled with the given points.
     *
     * @param parameters the iterated parameters; see {@link #RocData(IterableParameter[])}
     * @param rocDataPoints initial points, copied in array order; {@code null} is
     *        treated as "no points"
     * @throws NullPointerException if {@code parameters} is {@code null}
     * @throws IllegalArgumentException if {@code parameters} is empty
     */
    public RocData(IterableParameter[] parameters, RocDataPoint[] rocDataPoints) {
        this(parameters);
        // Copy from the ARGUMENT; the field is the empty list created above.
        if (rocDataPoints != null) {
            for (RocDataPoint point : rocDataPoints) {
                this.rocDataPoints.add(point);
            }
        }
    }

    /**
     * Creates an empty data set whose parameters are those settings entries
     * that report {@code isIterated()}.
     *
     * @param parameters iteration settings to filter
     * @return a new, empty {@code RocData}
     * @throws IllegalArgumentException if none of the settings is iterated
     *         (raised by the constructor for an empty parameter array)
     */
    public static RocData createForParameterIterationSettings(ParameterIterationSettings[] parameters) {
        List<IterableParameter> iteratedParameters = new ArrayList<IterableParameter>();
        for (ParameterIterationSettings parameter : parameters) {
            if (parameter.isIterated()) {
                iteratedParameters.add(parameter.getParameter());
            }
        }
        IterableParameter[] iteratedArr =
                iteratedParameters.toArray(new IterableParameter[iteratedParameters.size()]);
        return new RocData(iteratedArr);
    }

    /**
     * Convenience overload that takes the settings from {@code data.getParameters()}.
     *
     * @param data iterator data providing the parameter iteration settings
     * @return a new, empty {@code RocData}
     */
    public static RocData createForMethodIteratorData(MethodIteratorData data) {
        return createForParameterIterationSettings(data.getParameters());
    }

    /**
     * Appends a point and invalidates the cached statistics.
     *
     * @param rocDataPoint the point to append
     */
    public void add(RocDataPoint rocDataPoint) {
        rocDataPoints.add(rocDataPoint);
        dirtyStatistics = true;
    }

    /** @return the number of iterated parameters */
    public int getParameterCount() {
        return parameters.length;
    }

    /**
     * @param index parameter index, {@code 0 <= index < getParameterCount()}
     * @return the parameter at the given index
     */
    public IterableParameter getParameterAt(int index) {
        return parameters[index];
    }

    /** @return the number of stored points */
    public int getSampleCount() {
        return rocDataPoints.size();
    }

    /**
     * @param sample point index, {@code 0 <= sample < getSampleCount()}
     * @return the point at the given index
     */
    public RocDataPoint getRocDataPointAt(int sample) {
        return rocDataPoints.get(sample);
    }

    /**
     * @param index parameter index
     * @param sample point index
     * @return the value of the given parameter recorded in the given point
     */
    public Object getParameterValueAt(int index, int sample) {
        return rocDataPoints.get(sample).getParameterValues()[index];
    }

    /** @return the true positive count of the point at {@code sample} */
    public int getTruePositiveCount(int sample) {
        return rocDataPoints.get(sample).getTruePositiveCount();
    }

    /** @return the true negative count of the point at {@code sample} */
    public int getTrueNegativeCount(int sample) {
        return rocDataPoints.get(sample).getTrueNegativeCount();
    }

    /** @return the false positive count of the point at {@code sample} */
    public int getFalsePositiveCount(int sample) {
        return rocDataPoints.get(sample).getFalsePositiveCount();
    }

    /** @return the false negative count of the point at {@code sample} */
    public int getFalseNegativeCount(int sample) {
        return rocDataPoints.get(sample).getFalseNegativeCount();
    }

    /** @return the true rate of the point at {@code sample} */
    public double getTrueRateAt(int sample) {
        return rocDataPoints.get(sample).getTrueRate();
    }

    /** @return the false rate of the point at {@code sample} */
    public double getFalseRateAt(int sample) {
        return rocDataPoints.get(sample).getFalseRate();
    }

    /** @return a new array holding the true rate of every point, in point order */
    public double[] getTrueRates() {
        int cnt = rocDataPoints.size();
        double[] trueRates = new double[cnt];
        for (int i = 0; i < cnt; i++) {
            trueRates[i] = rocDataPoints.get(i).getTrueRate();
        }
        return trueRates;
    }

    /** @return a new array holding the false rate of every point, in point order */
    public double[] getFalseRates() {
        int cnt = rocDataPoints.size();
        double[] falseRates = new double[cnt];
        for (int i = 0; i < cnt; i++) {
            falseRates[i] = rocDataPoints.get(i).getFalseRate();
        }
        return falseRates;
    }

    /** Tells whether the point lies on or above the line {@code trueRate = 1 - falseRate}. */
    private boolean isAboveOSS(RocDataPoint point) {
        return (point.getTrueRate() >= (1 - point.getFalseRate()));
    }

    /** Perpendicular distance of the point from the line {@code falseRate + trueRate = 1}. */
    private double getOSSDistance(RocDataPoint point) {
        return (Math.abs(1 - (point.getFalseRate() + point.getTrueRate())) / Math.sqrt(2));
    }

    /** Recomputes the statistics if points were added since the last computation. */
    private void ensureStatistics() {
        if (dirtyStatistics) {
            calculateStatistics();
        }
    }

    /**
     * Recomputes all cached statistics from the current points and clears the
     * dirty flag. With no points, the area and maximal accuracy are 0 and all
     * other statistics are -1.
     */
    private void calculateStatistics() {

        int cnt = rocDataPoints.size();

        if (cnt == 0) {

            areaUnderCurve = 0;
            maxAccuracy = 0;
            maxAccuracyIteration = -1;
            pointAboveOSSDistance = -1;
            pointAboveOSSIteration = -1;
            pointBelowOSSDistance = -1;
            pointBelowOSSIteration = -1;
            // same "not available" marker as in the non-empty case below
            ossIntersectionTP = -1;

        } else {

            areaUnderCurve = 0;

            RocDataPoint thisPoint;
            RocDataPoint prevPoint;
            double accuracy;
            double distance;

            thisPoint = rocDataPoints.get(0);

            // triangle between (0,0) and the first point
            areaUnderCurve += thisPoint.getFalseRate() * thisPoint.getTrueRate() / 2;

            maxAccuracy = thisPoint.getAccuracy();
            maxAccuracyIteration = 0;

            // indices are 0-based while scanning; converted to 1-based at the end
            pointAboveOSSIteration = -1;
            pointBelowOSSIteration = -1;
            pointAboveOSSDistance = Double.MAX_VALUE;
            pointBelowOSSDistance = Double.MAX_VALUE;

            if (isAboveOSS(thisPoint)) {
                pointAboveOSSDistance = getOSSDistance(thisPoint);
                pointAboveOSSIteration = 0;
            } else {
                pointBelowOSSDistance = getOSSDistance(thisPoint);
                pointBelowOSSIteration = 0;
            }

            for (int i = 1; i < cnt; i++) {

                prevPoint = thisPoint;
                thisPoint = rocDataPoints.get(i);

                // trapezoid between the previous and the current point
                areaUnderCurve += (0.5 * (prevPoint.getTrueRate() + thisPoint.getTrueRate())
                        * (thisPoint.getFalseRate() - prevPoint.getFalseRate()));

                // strict comparison: the earliest point wins on ties
                accuracy = thisPoint.getAccuracy();
                if (accuracy > maxAccuracy) {
                    maxAccuracy = accuracy;
                    maxAccuracyIteration = i;
                }

                distance = getOSSDistance(thisPoint);
                if (isAboveOSS(thisPoint)) {
                    if (distance < pointAboveOSSDistance) {
                        pointAboveOSSDistance = distance;
                        pointAboveOSSIteration = i;
                    }
                } else {
                    if (distance < pointBelowOSSDistance) {
                        pointBelowOSSDistance = distance;
                        pointBelowOSSIteration = i;
                    }
                }

            }

            // trapezoid between the last point and (1,1)
            areaUnderCurve += (0.5 * (thisPoint.getTrueRate() + 1) * (1 - thisPoint.getFalseRate()));

            if (pointAboveOSSIteration >= 0 && pointBelowOSSIteration >= 0) {

                thisPoint = rocDataPoints.get(pointAboveOSSIteration);
                double xa = thisPoint.getFalseRate();
                double ya = thisPoint.getTrueRate();

                thisPoint = rocDataPoints.get(pointBelowOSSIteration);
                double xb = thisPoint.getFalseRate();
                double yb = thisPoint.getTrueRate();

                // True-rate coordinate where the segment A-B crosses x + y = 1.
                // The denominator is non-zero: A has xa + ya >= 1, B has xb + yb < 1.
                ossIntersectionTP = (yb * (1 - xa) - ya * (1 - xb)) / ((xb + yb) - (xa + ya));

            } else {
                ossIntersectionTP = -1;
            }

            // Replace the MAX_VALUE placeholder by -1 when no point was found on
            // a side; otherwise publish the iteration number as 1-based.
            if (pointAboveOSSIteration < 0) {
                pointAboveOSSDistance = -1;
            } else {
                pointAboveOSSIteration++;
            }
            if (pointBelowOSSIteration < 0) {
                pointBelowOSSDistance = -1;
            } else {
                pointBelowOSSIteration++;
            }
            if (maxAccuracyIteration >= 0) {
                maxAccuracyIteration++;
            }

        }

        dirtyStatistics = false;

    }

    /**
     * @return the area under the polyline (0,0) - points in insertion order - (1,1),
     *         computed with the trapezoidal rule; 0 when there are no points
     */
    public double getAreaUnderCurve() {
        ensureStatistics();
        return areaUnderCurve;
    }

    /** @return the largest accuracy among the points; 0 when there are no points */
    public double getMaxAccuracy() {
        ensureStatistics();
        return maxAccuracy;
    }

    /**
     * @return the 1-based index of the first point with maximal accuracy,
     *         or -1 when there are no points
     */
    public int getMaxAccuracyIteration() {
        ensureStatistics();
        return maxAccuracyIteration;
    }

    /**
     * @return the 1-based index of the point below the line
     *         {@code trueRate = 1 - falseRate} that is closest to it, or -1 if none
     */
    public int getPointBelowOSSIteration() {
        ensureStatistics();
        return pointBelowOSSIteration;
    }

    /** @return the distance of that closest point below the line, or -1 if none */
    public double getPointBelowOSSDistance() {
        ensureStatistics();
        return pointBelowOSSDistance;
    }

    /**
     * @return the 1-based index of the point on or above the line
     *         {@code trueRate = 1 - falseRate} that is closest to it, or -1 if none
     */
    public int getPointAboveOSSIteration() {
        ensureStatistics();
        return pointAboveOSSIteration;
    }

    /** @return the distance of that closest point on or above the line, or -1 if none */
    public double getPointAboveOSSDistance() {
        ensureStatistics();
        return pointAboveOSSDistance;
    }

    /**
     * @return the true-rate coordinate at which the segment joining the closest
     *         point above and the closest point below the line
     *         {@code trueRate = 1 - falseRate} crosses that line, or -1 when
     *         either of the two points does not exist
     */
    public double getOssIntersectionTP() {
        ensureStatistics();
        return ossIntersectionTP;
    }

    /**
     * Writes the data as a table: one header row, then one row per point.
     * Each row holds the parameter columns followed by the counts and the
     * derived rates; cells are joined with {@code columnSeparator} and every
     * row (including the header) is terminated with {@code rowSeparator}.
     *
     * @param writer destination; it is neither flushed nor closed here
     * @param columnSeparator text written between cells
     * @param rowSeparator text written after each row
     * @param userObject not used by this implementation
     * @throws IOException if the writer fails
     */
    @Override
    public void export(Writer writer, String columnSeparator, String rowSeparator, Object userObject) throws IOException {

        int i;
        int e;

        // header row
        for (i = 0; i < parameters.length; i++) {
            writer.append(parameters[i].getName());
            writer.append(columnSeparator);
        }

        writer.append("TP");
        writer.append(columnSeparator);
        writer.append("FP");
        writer.append(columnSeparator);
        writer.append("TN");
        writer.append(columnSeparator);
        writer.append("FN");
        writer.append(columnSeparator);
        writer.append("FP rate");
        writer.append(columnSeparator);
        writer.append("TP rate");
        writer.append(columnSeparator);
        writer.append("sensitivity");
        writer.append(columnSeparator);
        writer.append("specifity");
        writer.append(columnSeparator);
        writer.append("accuracy");
        writer.append(columnSeparator);
        writer.append("positive_pred_value");
        writer.append(columnSeparator);
        writer.append("negative_pred_value");
        writer.append(columnSeparator);
        writer.append("false_discovery_rate");
        writer.append(rowSeparator);

        // data rows, same column order as the header
        int sampleCount = rocDataPoints.size();
        RocDataPoint thisPoint;

        for (e = 0; e < sampleCount; e++) {

            thisPoint = rocDataPoints.get(e);

            for (i = 0; i < parameters.length; i++) {
                writer.append(thisPoint.getParameterValues()[i].toString());
                writer.append(columnSeparator);
            }

            writer.append(Integer.toString(thisPoint.getTruePositiveCount()));
            writer.append(columnSeparator);
            writer.append(Integer.toString(thisPoint.getFalsePositiveCount()));
            writer.append(columnSeparator);
            writer.append(Integer.toString(thisPoint.getTrueNegativeCount()));
            writer.append(columnSeparator);
            writer.append(Integer.toString(thisPoint.getFalseNegativeCount()));
            writer.append(columnSeparator);
            writer.append(Double.toString(thisPoint.getFalseRate()));
            writer.append(columnSeparator);
            writer.append(Double.toString(thisPoint.getTrueRate()));
            writer.append(columnSeparator);
            writer.append(Double.toString(thisPoint.getSensitivity()));
            writer.append(columnSeparator);
            writer.append(Double.toString(thisPoint.getSpecifity()));
            writer.append(columnSeparator);
            writer.append(Double.toString(thisPoint.getAccuracy()));
            writer.append(columnSeparator);
            writer.append(Double.toString(thisPoint.getPositivePredictiveValue()));
            writer.append(columnSeparator);
            writer.append(Double.toString(thisPoint.getNegativePredictiveValue()));
            writer.append(columnSeparator);
            writer.append(Double.toString(thisPoint.getFalseDiscoveryRate()));
            writer.append(rowSeparator);

        }

    }

    /**
     * Lists the statistics of this object as labelled, read-only properties,
     * each bound to the corresponding getter of this class.
     *
     * @return a new list of property descriptors
     * @throws IntrospectionException if a descriptor cannot be created
     */
    @Override
    public List<LabelledPropertyDescriptor> getPropertyList() throws IntrospectionException {

        LinkedList<LabelledPropertyDescriptor> list = new LinkedList<LabelledPropertyDescriptor>();

        list.add(new LabelledPropertyDescriptor(_("area under curve"), "areaUnderCurve", RocData.class, "getAreaUnderCurve", null));
        list.add(new LabelledPropertyDescriptor(_("maximal accuracy iteration"), "maxAccuracyIteration", RocData.class, "getMaxAccuracyIteration", null));
        list.add(new LabelledPropertyDescriptor(_("maximal accuracy"), "maxAccuracy", RocData.class, "getMaxAccuracy", null));
        // NOTE(review): "abose" looks like a typo for "above"; left unchanged because the
        // literal is presumably also the translation lookup key - fix together with the catalogs.
        list.add(new LabelledPropertyDescriptor(_("point abose OSS iteration"), "pointAboveOSSIteration", RocData.class, "getPointAboveOSSIteration", null));
        list.add(new LabelledPropertyDescriptor(_("point above OSS distance"), "pointAboveOSSDistance", RocData.class, "getPointAboveOSSDistance", null));
        list.add(new LabelledPropertyDescriptor(_("point below OSS iteration"), "pointBelowOSSIteration", RocData.class, "getPointBelowOSSIteration", null));
        list.add(new LabelledPropertyDescriptor(_("point below OSS distance"), "pointBelowOSSDistance", RocData.class, "getPointBelowOSSDistance", null));
        list.add(new LabelledPropertyDescriptor(_("OSS intersection TP"), "ossIntersectionTP", RocData.class, "getOssIntersectionTP", null));

        return list;

    }

}