package edu.hawaii.jmotif.experiment.cbf;
import java.io.IOException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.logging.ConsoleHandler;
import java.util.logging.Formatter;
import java.util.logging.Handler;
import java.util.logging.Logger;
import org.hackystat.utilities.logger.HackystatLogger;
import edu.hawaii.jmotif.distance.EuclideanDistance;
import edu.hawaii.jmotif.timeseries.TSException;
import edu.hawaii.jmotif.util.BriefFormatter;
/**
* Helper-runner for CBF test.
*
* @author psenin
*
*/
public class CBFKNNEuclideanClassifier {
// various variables
// classifier test parameters
//
/** The timeseries length. */
private static final int SERIES_LENGTH = 128;
/** The test set size. */
private static final int TEST_SAMPLE_SIZE = 300;
private static Logger consoleLogger;
private static String LOGGING_LEVEL = "FINE";
// uncoment that to remove early abandoning warning
//
// private static String LOGGING_LEVEL = "FINE";
static {
consoleLogger = HackystatLogger.getLogger("debug.console", "preseries");
consoleLogger.setUseParentHandlers(false);
for (Handler handler : consoleLogger.getHandlers()) {
consoleLogger.removeHandler(handler);
}
ConsoleHandler handler = new ConsoleHandler();
Formatter formatter = new BriefFormatter();
handler.setFormatter(formatter);
consoleLogger.addHandler(handler);
HackystatLogger.setLoggingLevel(consoleLogger, LOGGING_LEVEL);
}
/**
* @param args
* @throws TSException
* @throws IndexOutOfBoundsException
* @throws IOException
*/
public static void main(String[] args) throws IndexOutOfBoundsException, TSException, IOException {
// making training and test collections
Map<String, List<double[]>> trainData = new HashMap<String, List<double[]>>();
Map<String, List<double[]>> testData = new HashMap<String, List<double[]>>();
// ticks - i.e. time
int[] t = new int[SERIES_LENGTH];
for (int i = 0; i < SERIES_LENGTH; i++) {
t[i] = i;
}
int TRAINING_SET_SIZE = Integer.valueOf(args[0]);
// cylinder sample
List<double[]> cylinders = new ArrayList<double[]>();
for (int i = 0; i < TRAINING_SET_SIZE + TEST_SAMPLE_SIZE; i++) {
cylinders.add(CBFGenerator.cylinder(t));
}
trainData.put("0", extract(cylinders, 0, TRAINING_SET_SIZE));
testData.put("0", extract(cylinders, TRAINING_SET_SIZE, TRAINING_SET_SIZE + TEST_SAMPLE_SIZE));
// bell sample
List<double[]> bells = new ArrayList<double[]>();
for (int i = 0; i < TRAINING_SET_SIZE + TEST_SAMPLE_SIZE; i++) {
bells.add(CBFGenerator.bell(t));
}
trainData.put("1", extract(bells, 0, TRAINING_SET_SIZE));
testData.put("1", extract(bells, TRAINING_SET_SIZE, TRAINING_SET_SIZE + TEST_SAMPLE_SIZE));
// funnel sample
List<double[]> funnels = new ArrayList<double[]>();
for (int i = 0; i < TRAINING_SET_SIZE + TEST_SAMPLE_SIZE; i++) {
funnels.add(CBFGenerator.funnel(t));
}
trainData.put("2", extract(funnels, 0, TRAINING_SET_SIZE));
testData.put("2", extract(funnels, TRAINING_SET_SIZE, TRAINING_SET_SIZE + TEST_SAMPLE_SIZE));
// ################ begin classification
//
int totalPositiveTests = 0;
int totalTestSample = TEST_SAMPLE_SIZE * 3;
int queryCounter = 0;
NumberFormat df = new DecimalFormat("0.00");
// #### here we iterate over all TEST series, class by class, series by series
//
for (Entry<String, List<double[]>> querySet : testData.entrySet()) {
for (double[] querySeries : querySet.getValue()) {
consoleLogger.fine("classifying query " + queryCounter + " of class " + querySet.getKey());
// this holds the closest neighbor out of all training data with its class
//
double bestDistance = Double.MAX_VALUE;
String bestClass = "";
// #### here we iterate over all TRAIN series, class by class, series by series
//
for (Entry<String, List<double[]>> referenceSet : trainData.entrySet()) {
for (double[] referenceSeries : referenceSet.getValue()) {
// this computes the Euclidean distance.
// earlyAbandonedDistance implementation abandons full distance computation
// if current value is above the best known
//
Double distance = EuclideanDistance.earlyAbandonedDistance(querySeries,
referenceSeries, bestDistance);
// Double distance = EuclideanDistance.earlyAbandonedDistance(
// TSUtils.zNormalize(querySeries), TSUtils.zNormalize(referenceSeries), bestDistance);
if (null != distance && distance.doubleValue() < bestDistance) {
bestDistance = distance.doubleValue();
bestClass = referenceSet.getKey();
consoleLogger.fine(" + closest class: " + bestClass + " distance: " + bestDistance);
}
else {
// consoleLogger.fine(" - abandoned search for class: " + referenceSet.getKey()
// + ", distance: " + EuclideanDistance.distance(querySeries, referenceSeries));
// consoleLogger.fine(" - abandoned search for class: "
// + referenceSet.getKey()
// + ", distance: "
// + EuclideanDistance.distance(TSUtils.zNormalize(querySeries),
// TSUtils.zNormalize(referenceSeries)));
}
}
}
if (bestClass.equalsIgnoreCase(querySet.getKey())) {
totalPositiveTests++;
}
queryCounter++;
}
}
double accuracy = (double) totalPositiveTests / (double) totalTestSample;
double error = 1.0d - accuracy;
System.out.println(accuracy + "," + error + "\n");
}
private static List<double[]> extract(List<double[]> cylinders, int start, int end) {
List<double[]> res = new ArrayList<double[]>();
for (int i = start; i < end; i++) {
res.add(cylinders.get(i));
}
return res;
}
}