/*
* ARX: Powerful Data Anonymization
* Copyright 2012 - 2017 Fabian Prasser, Florian Kohlmayer and contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.deidentifier.arx.aggregates;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.deidentifier.arx.ARXLogisticRegressionConfiguration;
import org.deidentifier.arx.DataHandleInternal;
import org.deidentifier.arx.aggregates.classification.ClassificationDataSpecification;
import org.deidentifier.arx.aggregates.classification.ClassificationMethod;
import org.deidentifier.arx.aggregates.classification.ClassificationResult;
import org.deidentifier.arx.aggregates.classification.MultiClassLogisticRegression;
import org.deidentifier.arx.aggregates.classification.MultiClassZeroR;
import org.deidentifier.arx.common.WrappedBoolean;
import org.deidentifier.arx.common.WrappedInteger;
import org.deidentifier.arx.exceptions.ComputationInterruptedException;
/**
* Statistics representing the prediction accuracy of a data mining
* classification operator
*
* @author Fabian Prasser
*/
public class StatisticsClassification {
/**
 * A matrix mapping confidence thresholds to precision and recall
 *
 * @author Fabian Prasser
 *
 */
public static class PrecisionRecallMatrix {

    /** Confidence thresholds at which precision and recall are evaluated */
    private static final double[] CONFIDENCE_THRESHOLDS = new double[]{
        0d, 0.1d, 0.2d, 0.3d, 0.4d, 0.5d, 0.6d, 0.7d, 0.8d, 0.9d, 1d
    };

    /** Total number of measurements recorded via add() */
    private double measurements = 0d;
    /** Per threshold: raw count of correct classifications; after pack(), the precision */
    private final double[] precision = new double[CONFIDENCE_THRESHOLDS.length];
    /** Per threshold: raw count of classifications at or above the threshold; after pack(), the recall */
    private final double[] recall = new double[CONFIDENCE_THRESHOLDS.length];

    /**
     * Returns the confidence thresholds. A defensive copy is returned so
     * that callers cannot mutate the shared static array, which would
     * corrupt every instance of this class.
     *
     * @return the confidence thresholds
     */
    public double[] getConfidenceThresholds() {
        return CONFIDENCE_THRESHOLDS.clone();
    }

    /**
     * Returns the precision for each confidence threshold.
     * Only meaningful after pack() has been called.
     *
     * @return the precision
     */
    public double[] getPrecision() {
        return precision;
    }

    /**
     * Returns the recall for each confidence threshold.
     * Only meaningful after pack() has been called.
     *
     * @return the recall
     */
    public double[] getRecall() {
        return recall;
    }

    /**
     * Adds a new measurement. The classification is counted at every
     * threshold that does not exceed the given confidence, and counted as
     * correct where applicable.
     *
     * @param confidence the classifier's confidence for this classification
     * @param correct whether the classification was correct
     */
    void add(double confidence, boolean correct) {
        for (int i = 0; i < CONFIDENCE_THRESHOLDS.length; i++) {
            if (confidence >= CONFIDENCE_THRESHOLDS[i]) {
                recall[i]++;
                precision[i] += correct ? 1d : 0d;
            }
        }
        measurements++;
    }

    /**
     * Converts the raw counts into precision and recall values. Must be
     * called exactly once, after all measurements have been added.
     */
    void pack() {
        // Pack
        for (int i = 0; i < CONFIDENCE_THRESHOLDS.length; i++) {
            if (recall[i] == 0d) {
                // No classification reached this threshold: define precision as 1
                precision[i] = 1d;
            } else {
                // Order matters: precision must use the raw count in recall[i]
                // before recall[i] is normalized by the number of measurements
                precision[i] /= recall[i];
                recall[i] /= measurements;
            }
        }
    }
}
/** Accuracy of the classifier trained on the output dataset */
private double accuracy;
/** Average error of the classifier trained on the output dataset */
private double averageError;
/** Interrupt flag */
private final WrappedBoolean interrupt;
/** Progress indicator, updated in percent */
private final WrappedInteger progress;
/** Precision/recall matrix for the classifier trained on the output dataset */
private PrecisionRecallMatrix matrix = new PrecisionRecallMatrix();
/** Number of distinct classes of the target attribute */
private int numClasses;
/** Accuracy of the Logistic Regression classifier trained on the input dataset */
private double originalAccuracy;
/** Average error of the Logistic Regression classifier trained on the input dataset */
private double originalAverageError;
/** Precision/recall matrix for the classifier trained on the input dataset */
private PrecisionRecallMatrix originalMatrix = new PrecisionRecallMatrix();
/** Random source used for creating the cross-validation folds */
private final Random random;
/** Accuracy of the ZeroR baseline classifier */
private double zeroRAccuracy;
/** Average error of the ZeroR baseline classifier */
private double zeroRAverageError;
/** Number of classification events performed during evaluation */
private int numMeasurements;
/**
 * Creates a new set of statistics for the given classification task.
 * Performs k-fold cross validation, training Logistic Regression
 * classifiers on the input and (if different) output dataset, plus a
 * ZeroR baseline for comparison.
 *
 * @param inputHandle - The input features handle
 * @param outputHandle - The output features handle
 * @param features - The feature attributes
 * @param clazz - The class attributes
 * @param config - The configuration
 * @param interrupt - The interrupt flag
 * @param progress - Progress indicator, updated in percent
 * @throws ParseException
 */
StatisticsClassification(DataHandleInternal inputHandle,
DataHandleInternal outputHandle,
String[] features,
String clazz,
ARXLogisticRegressionConfiguration config,
WrappedBoolean interrupt,
WrappedInteger progress) throws ParseException {
// Init
this.interrupt = interrupt;
this.progress = progress;
// Check and clean up
// NOTE(review): samplingFraction is only validated here and never used to
// subsample the records afterwards — confirm whether config.getMaxRecords()
// is meant to be enforced elsewhere.
double samplingFraction = (double)config.getMaxRecords() / (double)inputHandle.getNumRows();
if (samplingFraction <= 0d) {
throw new IllegalArgumentException("Sampling fraction must be >0");
}
if (samplingFraction > 1d) {
samplingFraction = 1d;
}
// Initialize random: seeded (deterministic) or not, per configuration
if (!config.isDeterministic()) {
this.random = new Random();
} else {
this.random = new Random(config.getSeed());
}
// TODO: Feature is not used. Continuous variables are treated as categorical.
ClassificationDataSpecification specification = new ClassificationDataSpecification(inputHandle,
outputHandle,
features,
clazz,
interrupt);
// Train and evaluate: use at most one fold per row
int k = inputHandle.getNumRows() > config.getNumFolds() ? config.getNumFolds() : inputHandle.getNumRows();
List<List<Integer>> folds = getFolds(inputHandle.getNumRows(), k);
// Track progress: each row is visited once per fold (either for training
// or for validation), so done * total sums to 100 percent overall
int classifications = 0;
double total = 100d / ((double)inputHandle.getNumRows() * (double)folds.size());
double done = 0d;
// For each fold as a validation set
for (int evaluationFold = 0; evaluationFold < folds.size(); evaluationFold++) {
// Create classifiers: fresh instances per fold
ClassificationMethod inputLR = new MultiClassLogisticRegression(specification, config);
ClassificationMethod inputZR = new MultiClassZeroR(specification);
ClassificationMethod outputLR = null;
if (inputHandle != outputHandle) {
outputLR = new MultiClassLogisticRegression(specification, config);
}
// Try: any failure aborts the whole computation
try {
// Train with all training sets (every fold except the evaluation fold)
boolean trained = false;
for (int trainingFold = 0; trainingFold < folds.size(); trainingFold++) {
if (trainingFold != evaluationFold) {
for (int index : folds.get(trainingFold)) {
checkInterrupt();
inputLR.train(inputHandle, outputHandle, index);
inputZR.train(inputHandle, outputHandle, index);
// Records suppressed in the output are skipped for the output classifier
if (outputLR != null && !outputHandle.isOutlier(index)) {
outputLR.train(outputHandle, outputHandle, index);
}
trained = true;
this.progress.value = (int)((++done) * total);
}
}
}
// Close: finalize training before classification
inputLR.close();
inputZR.close();
if (outputLR != null) {
outputLR.close();
}
// Now validate on the held-out evaluation fold
for (int index : folds.get(evaluationFold)) {
// Check
checkInterrupt();
// If trained (false only when there was no training fold, i.e. k == 1)
if (trained) {
// Classify
ClassificationResult resultInputLR = inputLR.classify(inputHandle, index);
ClassificationResult resultInputZR = inputZR.classify(inputHandle, index);
ClassificationResult resultOutputLR = outputLR == null ? null : outputLR.classify(outputHandle, index);
classifications++;
// Correct result: the ground-truth class value for this record
String actualValue = outputHandle.getValue(index, specification.classIndex, true);
// Maintain data about inputZR
this.zeroRAverageError += resultInputZR.error(actualValue);
this.zeroRAccuracy += resultInputZR.correct(actualValue) ? 1d : 0d;
// Maintain data about inputLR
boolean correct = resultInputLR.correct(actualValue);
this.originalAverageError += resultInputLR.error(actualValue);
this.originalAccuracy += correct ? 1d : 0d;
this.originalMatrix.add(resultInputLR.confidence(), correct);
// Maintain data about outputLR
if (resultOutputLR != null) {
correct = resultOutputLR.correct(actualValue);
this.averageError += resultOutputLR.error(actualValue);
this.accuracy += correct ? 1d : 0d;
this.matrix.add(resultOutputLR.confidence(), correct);
}
}
this.progress.value = (int)((++done) * total);
}
} catch (Exception e) {
// NOTE(review): broad catch deliberately wraps any classifier failure;
// interruption surfaces separately via ComputationInterruptedException
throw new RuntimeException(e);
}
}
// Normalize accumulated sums into averages.
// NOTE(review): if nothing was trained (e.g. a single fold), classifications
// stays 0 and these divisions yield NaN — confirm the caller guarantees k >= 2.
// Maintain data about inputZR
this.zeroRAverageError /= (double)classifications;
this.zeroRAccuracy/= (double)classifications;
// Maintain data about inputLR
this.originalAverageError /= (double)classifications;
this.originalAccuracy /= (double)classifications;
this.originalMatrix.pack();
// Maintain data about outputLR; if input == output, reuse the input results
if (inputHandle != outputHandle) {
this.averageError /= (double)classifications;
this.accuracy /= (double)classifications;
this.matrix.pack();
} else {
this.averageError = this.originalAverageError;
this.accuracy = this.originalAccuracy;
this.matrix = this.originalMatrix;
}
this.numClasses = specification.classMap.size();
this.numMeasurements = classifications;
}
/**
 * Returns the resulting accuracy, obtained by training a Logistic
 * Regression classifier on the output (or input) dataset.
 *
 * @return the accuracy
 */
public double getAccuracy() {
    return accuracy;
}
/**
 * Returns the average error on the output (or input) dataset, defined as
 * avg(1d-probability-of-correct-result) over all classification events.
 *
 * @return the average error
 */
public double getAverageError() {
    return averageError;
}
/**
 * Returns how many distinct classes the target attribute has.
 *
 * @return the number of classes
 */
public int getNumClasses() {
    return numClasses;
}
/**
 * Returns how many classification events were measured.
 *
 * @return the number of measurements
 */
public int getNumMeasurements() {
    return numMeasurements;
}
/**
 * Returns the maximal accuracy, obtained by training a Logistic
 * Regression classifier on the input dataset.
 *
 * @return the accuracy on the input dataset
 */
public double getOriginalAccuracy() {
    return originalAccuracy;
}
/**
 * Returns the average error of the Logistic Regression classifier trained on
 * the input dataset, defined as avg(1d-probability-of-correct-result) for
 * each classification event.
 *
 * @return
 */
public double getOriginalAverageError() {
return this.originalAverageError;
}
/**
 * Returns the precision/recall matrix for Logistic Regression trained
 * on the input dataset.
 *
 * @return the precision/recall matrix
 */
public PrecisionRecallMatrix getOriginalPrecisionRecall() {
    return originalMatrix;
}
/**
 * Returns the precision/recall matrix for the output (or input) dataset.
 *
 * @return the precision/recall matrix
 */
public PrecisionRecallMatrix getPrecisionRecall() {
    return matrix;
}
/**
 * Returns the minimal accuracy, obtained by training a ZeroR classifier
 * on the input dataset.
 *
 * @return the ZeroR baseline accuracy
 */
public double getZeroRAccuracy() {
    return zeroRAccuracy;
}
/**
 * Returns the average error of the ZeroR baseline classifier, defined as
 * avg(1d-probability-of-correct-result) for each classification event.
 *
 * @return
 */
public double getZeroRAverageError() {
return this.zeroRAverageError;
}
@Override
public String toString() {
    // Assemble a human-readable multi-line summary of all statistics
    return "StatisticsClassification{\n" +
           " - Accuracy:\n" +
           "   * Original: " + originalAccuracy + "\n" +
           "   * ZeroR: " + zeroRAccuracy + "\n" +
           "   * Output: " + accuracy + "\n" +
           " - Average error:\n" +
           "   * Original: " + originalAverageError + "\n" +
           "   * ZeroR: " + zeroRAverageError + "\n" +
           "   * Output: " + averageError + "\n" +
           " - Number of classes: " + numClasses + "\n" +
           " - Number of measurements: " + numMeasurements + "\n" +
           "}";
}
/**
 * Aborts the computation with a ComputationInterruptedException
 * if the interrupt flag has been set.
 */
private void checkInterrupt() {
    if (!interrupt.value) {
        return;
    }
    throw new ComputationInterruptedException("Interrupted");
}
/**
 * Randomly partitions the row indices [0, length) into k disjoint folds
 * for cross validation, using this instance's random source.
 *
 * @param length the total number of rows
 * @param k the number of folds; must not exceed length
 * @return a list of k folds covering all row indices
 */
private List<List<Integer>> getFolds(int length, int k) {
    // Prepare shuffled indexes
    List<Integer> rows = new ArrayList<>(length);
    for (int row = 0; row < length; row++) {
        rows.add(row);
    }
    Collections.shuffle(rows, random);
    // Create folds: each gets at least one row, the last absorbs the remainder
    List<List<Integer>> folds = new ArrayList<>(k);
    int size = rows.size() / k;
    size = size > 1 ? size : 1;
    for (int i = 0; i < k; i++) {
        // For each fold
        int min = i * size;
        int max = (i + 1) * size;
        if (i == k - 1) {
            max = rows.size();
        }
        // Collect rows via a copied view of the shuffled list
        folds.add(new ArrayList<>(rows.subList(min, max)));
    }
    return folds;
}
}