package edu.hawaii.jmotif.performance.digits;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.hackystat.utilities.stacktrace.StackTrace;
import edu.hawaii.jmotif.performance.UCRGenericClassifier;
/**
* Helper-runner for test.
*
* @author psenin
*
*/
public class DigitsKNNDTW extends UCRGenericClassifier {
// data locations
//
private static final String TRAINING_DATA = "data/digits/train_centered.csv";
private static final String TEST_DATA = "data/digits/test_centered.csv";
private static final int threadsNum = 6;
/**
* Runnable.
*
* @throws Exception if error occurs.
*/
public static void main(String[] args) throws Exception {
// making training and test collections
//
Map<String, double[]> trainData = readTrainData(TRAINING_DATA);
List<double[]> testData = readTestData(TEST_DATA);
BufferedWriter bw = new BufferedWriter(new FileWriter(new File("data/digits/dtw_knn.csv")));
bw.write("ImageId,Label\n");
// create thread pool for processing these users
//
ExecutorService executorService = Executors.newFixedThreadPool(threadsNum);
CompletionService<String> completionService = new ExecutorCompletionService<String>(
executorService);
int totalTaskCounter = 0;
int seriesCounter = 1;
for (double[] series : testData) {
// create and submit the job
final DTWknnJob job = new DTWknnJob(series, seriesCounter, trainData);
completionService.submit(job);
totalTaskCounter++;
seriesCounter++;
}
// waiting for completion, shutdown pool disabling new tasks from being submitted
executorService.shutdown();
consoleLogger.info("Submitted " + totalTaskCounter + " jobs, shutting down the pool");
try {
while (totalTaskCounter > 0) {
//
// poll with a wait up to FOUR hours
Future<String> finished = completionService.poll(96, TimeUnit.HOURS);
if (null == finished) {
// something went wrong - break from here
System.err.println("Breaking POLL loop after 48 HOURS of waiting...");
break;
}
else {
String res = finished.get();
if (!(res.startsWith("ok_"))) {
System.err.println("Exception caught: " + finished.get());
break;
}
else {
consoleLogger.info(res);
bw.write(res + "\n");
}
totalTaskCounter--;
}
}
consoleLogger.info("All jobs completed.");
}
catch (Exception e) {
System.err.println("Error while waiting results: " + StackTrace.toString(e));
}
finally {
// wait at least 1 more hour before terminate and fail
try {
if (!executorService.awaitTermination(1, TimeUnit.HOURS)) {
executorService.shutdownNow(); // Cancel currently executing tasks
if (!executorService.awaitTermination(30, TimeUnit.MINUTES))
System.err.println("Pool did not terminate... FATAL ERROR");
}
}
catch (InterruptedException ie) {
System.err.println("Error while waiting interrupting: " + StackTrace.toString(ie));
// (Re-)Cancel if current thread also interrupted
executorService.shutdownNow();
// Preserve interrupt status
Thread.currentThread().interrupt();
}
}
bw.close();
}
private static Map<String, double[]> readTrainData(String fileName) throws NumberFormatException,
IOException {
Map<String, double[]> res = new HashMap<String, double[]>();
BufferedReader br = new BufferedReader(new FileReader(new File(fileName)));
String line = "";
int counter = 0;
while ((line = br.readLine()) != null) {
if (line.trim().length() == 0) {
continue;
}
String[] split = line.trim().split(",|\\s+");
String label = split[0];
double[] series = new double[split.length - 1];
for (int i = 1; i < split.length; i++) {
series[i - 1] = Double.valueOf(split[i].trim()).doubleValue();
}
res.put(label + "_" + String.valueOf(counter), series);
counter++;
}
br.close();
return res;
}
private static ArrayList<double[]> readTestData(String fileName) throws NumberFormatException,
IOException {
ArrayList<double[]> res = new ArrayList<double[]>();
BufferedReader br = new BufferedReader(new FileReader(new File(fileName)));
String line = "";
while ((line = br.readLine()) != null) {
if (line.trim().length() == 0) {
continue;
}
String[] split = line.trim().split(",|\\s+");
double[] series = new double[split.length];
for (int i = 0; i < split.length; i++) {
series[i] = Double.valueOf(split[i].trim()).doubleValue();
}
res.add(series);
}
br.close();
return res;
}
private static Double parseValue(String string) {
Double res = Double.NaN;
try {
Double r = Double.valueOf(string);
res = r;
}
catch (NumberFormatException e) {
assert true;
}
return res;
}
}