package edu.hawaii.jmotif.performance.digits;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeSet;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.hackystat.utilities.stacktrace.StackTrace;
import cc.mallet.util.Randoms;
import edu.hawaii.jmotif.performance.UCRGenericClassifier;
import edu.hawaii.jmotif.performance.UCRLOOCVErrorFunction;
import edu.hawaii.jmotif.performance.UCRUtils;
import edu.hawaii.jmotif.sampler.DirectMethod;
import edu.hawaii.jmotif.sampler.ObjectiveFunction;
import edu.hawaii.jmotif.sampler.PrintConsumer;
import edu.hawaii.jmotif.sampler.Solver;
import edu.hawaii.jmotif.sampler.UCRSolver;
import edu.hawaii.jmotif.text.SAXCollectionStrategy;
/**
* Helper-runner for CBF test.
*
* @author psenin
*
*/
public class UCRDigitsDirectSampler extends UCRGenericClassifier {
// num of threads to use
//
private static final int THREADS_NUM = 3;
// data
//
private static final String TRAINING_DATA = "data/digits/digits_reduced_400.csv";
// output prefix
//
private static final String outputPrefix = "digits_direct";
// SAX parameters to use
//
private static final int WINDOW_MIN = 170;
private static final int WINDOW_MAX = 190;
private static final int PAA_MIN = 14;
private static final int PAA_MAX = 18;
private static final int ALPHABET_MIN = 2;
private static final int ALPHABET_MAX = 5;
private static final int HOLD_OUT_NUM = 2;
private static final int MAX_ITERATIONS = 3;
private static List<String> globalResults = new ArrayList<String>();
private UCRDigitsDirectSampler() {
super();
}
/**
* @param args
* @throws Exception
*/
@SuppressWarnings("unchecked")
public static void main(String[] args) throws Exception {
Map<String, List<double[]>> td = UCRUtils.readUCRData(TRAINING_DATA);
consoleLogger.fine("reading file: " + TRAINING_DATA);
consoleLogger.fine("trainData classes: " + td.size() + ", series length: "
+ td.entrySet().iterator().next().getValue().get(0).length);
for (Entry<String, List<double[]>> e : td.entrySet()) {
consoleLogger.fine(" training class: " + e.getKey() + " series: " + e.getValue().size());
}
consoleLogger.fine("re-sampling... ");
Map<String, List<double[]>> fullData = resample(td, 8);
consoleLogger.fine("trainData classes: " + fullData.size() + ", series length: "
+ fullData.entrySet().iterator().next().getValue().get(0).length);
for (Entry<String, List<double[]>> e : fullData.entrySet()) {
consoleLogger.fine(" training class: " + e.getKey() + " series: " + e.getValue().size());
}
for (int i = 1; i < 10; i++) {
globalResults = new ArrayList<String>();
String currentClass = String.valueOf(i);
consoleLogger.fine(" separating class " + currentClass);
Map<String, List<double[]>> trainData = new HashMap<String, List<double[]>>();
trainData.put(currentClass, makeACopy(fullData.get(currentClass)));
List<double[]> other = new ArrayList<double[]>();
for (int k = 1; k < 10; k++) {
if (!(currentClass.equalsIgnoreCase(String.valueOf(k)))) {
other.addAll(resampleSubset(fullData.get(String.valueOf(k)), 3));
}
}
trainData.put("other", other);
consoleLogger.fine("After separation: trainData classes: " + trainData.size()
+ ", series length: " + trainData.entrySet().iterator().next().getValue().get(0).length);
for (Entry<String, List<double[]>> e : trainData.entrySet()) {
consoleLogger.fine(" training class: " + e.getKey() + " series: " + e.getValue().size());
}
double[] parametersLowest = { Double.valueOf(WINDOW_MIN), Double.valueOf(PAA_MIN),
Double.valueOf(ALPHABET_MIN) };
double[] parametersHighest = { Double.valueOf(WINDOW_MAX), Double.valueOf(PAA_MAX),
Double.valueOf(ALPHABET_MAX) };
ExecutorService executorService = Executors.newFixedThreadPool(THREADS_NUM);
CompletionService<List<String>> completionService = new ExecutorCompletionService<List<String>>(
executorService);
int totalTaskCounter = 0;
// create and submit the job for NOREDUCTION
//
ObjectiveFunction noredFunction = new UCRLOOCVErrorFunction();
noredFunction.setStrategy(SAXCollectionStrategy.NOREDUCTION);
PrintConsumer noredConsumer = new PrintConsumer(SAXCollectionStrategy.NOREDUCTION);
noredFunction.setUpperBounds(parametersHighest);
noredFunction.setLowerBounds(parametersLowest);
noredFunction.setData(trainData, HOLD_OUT_NUM);
DirectMethod noredMethod = new DirectMethod();
noredMethod.addConsumer(noredConsumer);
Solver noredSolver = new UCRSolver(MAX_ITERATIONS);
noredSolver.init(noredFunction, noredMethod);
completionService.submit((Callable<List<String>>) noredSolver);
totalTaskCounter++;
// create and submit the job for EXACT
//
ObjectiveFunction exactFunction = new UCRLOOCVErrorFunction();
exactFunction.setStrategy(SAXCollectionStrategy.EXACT);
PrintConsumer exactConsumer = new PrintConsumer(SAXCollectionStrategy.EXACT);
exactFunction.setUpperBounds(parametersHighest);
exactFunction.setLowerBounds(parametersLowest);
exactFunction.setData(trainData, HOLD_OUT_NUM);
DirectMethod exactMethod = new DirectMethod();
exactMethod.addConsumer(exactConsumer);
Solver exactSolver = new UCRSolver(MAX_ITERATIONS);
exactSolver.init(exactFunction, exactMethod);
completionService.submit((Callable<List<String>>) exactSolver);
totalTaskCounter++;
// create and submit the job for CLASSIC
//
ObjectiveFunction classicFunction = new UCRLOOCVErrorFunction();
classicFunction.setStrategy(SAXCollectionStrategy.CLASSIC);
PrintConsumer classicConsumer = new PrintConsumer(SAXCollectionStrategy.CLASSIC);
classicFunction.setUpperBounds(parametersHighest);
classicFunction.setLowerBounds(parametersLowest);
classicFunction.setData(trainData, HOLD_OUT_NUM);
DirectMethod classicMethod = new DirectMethod();
classicMethod.addConsumer(classicConsumer);
Solver classicSolver = new UCRSolver(MAX_ITERATIONS);
classicSolver.init(classicFunction, classicMethod);
completionService.submit((Callable<List<String>>) classicSolver);
totalTaskCounter++;
// waiting for completion, shutdown pool disabling new tasks from being submitted
executorService.shutdown();
consoleLogger.info("Submitted " + totalTaskCounter + " jobs, shutting down the pool");
// waiting for jobs to finish
//
//
try {
while (totalTaskCounter > 0) {
//
// poll with a wait up to FOUR hours
Future<List<String>> finished = completionService.poll(128, TimeUnit.HOURS);
if (null == finished) {
// something went wrong - break from here
System.err.println("Breaking POLL loop after 128 HOURS of waiting...");
break;
}
else {
List<String> res = finished.get();
globalResults.addAll(res);
totalTaskCounter--;
}
}
consoleLogger.info("All jobs completed.");
}
catch (Exception e) {
System.err.println("Error while waiting results: " + StackTrace.toString(e));
}
finally {
// wait at least 1 more hour before terminate and fail
try {
if (!executorService.awaitTermination(1, TimeUnit.HOURS)) {
executorService.shutdownNow(); // Cancel currently executing tasks
if (!executorService.awaitTermination(30, TimeUnit.MINUTES))
System.err.println("Pool did not terminate... FATAL ERROR");
}
}
catch (InterruptedException ie) {
System.err.println("Error while waiting interrupting: " + StackTrace.toString(ie));
// (Re-)Cancel if current thread also interrupted
executorService.shutdownNow();
// Preserve interrupt status
Thread.currentThread().interrupt();
}
}
Collections.sort(globalResults, new Comparator<String>() {
@Override
public int compare(String arg0, String arg1) {
String[] line1 = arg0.split(",");
String[] line2 = arg0.split(",");
return Double.valueOf(line1[line1.length - 1]).compareTo(
Double.valueOf(line2[line2.length - 1]));
}
});
System.out.println("Best result: " + globalResults.get(globalResults.size()-1));
String paramsLine = globalResults.get(globalResults.size()-1);
String[] split = paramsLine.split(",");
String strategyKey = split[0];
}
BufferedWriter bw = new BufferedWriter(new FileWriter(outputPrefix + ".csv"));
for (String line : globalResults) {
bw.write(line + CR);
}
bw.close();
}
private static Collection<? extends double[]> resampleSubset(List<double[]> list, int resampleSize) {
Randoms rnd = new Randoms();
List<double[]> res = new ArrayList<double[]>();
TreeSet<Integer> seen = new TreeSet<Integer>();
for (int i = 0; i < resampleSize; i++) {
Integer idx = 0;
do {
idx = (int) Math.floor(rnd.nextUniform(0., (double) list.size()));
}
while (seen.contains(idx));
res.add(Arrays.copyOf(list.get(idx), list.get(idx).length));
seen.add(idx);
}
return res;
}
private static List<double[]> makeACopy(List<double[]> list) {
List<double[]> res = new ArrayList<double[]>();
for (double[] s : list) {
res.add(Arrays.copyOf(s, s.length));
}
return res;
}
private static Map<String, List<double[]>> resample(Map<String, List<double[]>> td,
int resampleSize) {
Randoms rnd = new Randoms();
Map<String, List<double[]>> res = new HashMap<String, List<double[]>>();
for (Entry<String, List<double[]>> e : td.entrySet()) {
List<double[]> list = e.getValue();
List<double[]> entry = new ArrayList<double[]>();
TreeSet<Integer> seen = new TreeSet<Integer>();
for (int i = 0; i < resampleSize; i++) {
Integer idx = 0;
do {
idx = (int) Math.floor(rnd.nextUniform(0., (double) list.size()));
}
while (seen.contains(idx));
entry.add(Arrays.copyOf(list.get(idx), list.get(idx).length));
seen.add(idx);
}
res.put(e.getKey(), entry);
}
return res;
}
}