package edu.hawaii.jmotif.performance.digits;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.Callable;
import com.dtw.TimeWarpInfo;
import com.timeseries.TimeSeries;
import com.timeseries.TimeSeriesPoint;
import com.util.DistanceFunction;
import com.util.DistanceFunctionFactory;
import edu.hawaii.jmotif.performance.KNNStackEntry;

/**
 * A {@link Callable} that classifies one series by majority vote among the training series with
 * the smallest FastDTW distance to it.
 *
 * <p>
 * Training labels must start with the decimal class index followed by an underscore (the text
 * before the first {@code '_'} is parsed as an integer and used as an index into a vote table of
 * {@value #CLASS_COUNT} slots).
 * </p>
 *
 * <p>
 * The series array and the training map are stored by reference; they are not copied.
 * </p>
 */
public class DTWknnJob implements Callable<String> {

  /** Maximum number of nearest neighbours retained for the vote. */
  private static final int NEIGHBOR_COUNT = 10;

  /** Number of class indexes the vote table can hold (valid indexes are 0 .. CLASS_COUNT - 1). */
  private static final int CLASS_COUNT = 10;

  /** Search radius argument handed to FastDTW. */
  private static final int SEARCH_RADIUS = 10;

  /** Name of the point distance function requested from {@link DistanceFunctionFactory}. */
  private static final String DISTANCE_FUNCTION_NAME = "EuclideanDistance";

  /** Orders neighbour entries by ascending distance. */
  private static final Comparator<KNNStackEntry<String, Double>> BY_DISTANCE =
      new Comparator<KNNStackEntry<String, Double>>() {
        @Override
        public int compare(KNNStackEntry<String, Double> left,
            KNNStackEntry<String, Double> right) {
          return left.getValue().compareTo(right.getValue());
        }
      };

  private final double[] series;
  private final int seriesCounter;
  private final Map<String, double[]> trainData;

  /**
   * Creates a classification job.
   *
   * @param series the series to classify.
   * @param seriesCounter an identifier of the series; echoed back in the result string.
   * @param trainData the training set, mapping a label to its series.
   */
  public DTWknnJob(double[] series, int seriesCounter, Map<String, double[]> trainData) {
    this.series = series;
    this.seriesCounter = seriesCounter;
    this.trainData = trainData;
  }

  /**
   * Finds the nearest training series and reports the vote.
   *
   * @return a string of the form
   *         {@code "ok_ [label, ...] : <seriesCounter>,<winning class>,<label in the last slot>"}.
   * @throws IllegalStateException if the training set is empty.
   */
  @Override
  public String call() throws Exception {
    return getVotes(getNeighbors(series, trainData));
  }

  /**
   * Formats the result string from the retained neighbours.
   *
   * @param neighbors the retained neighbours, in list order.
   * @return the formatted result, see {@link #call()}.
   * @throws IllegalStateException if there are no neighbours to vote with.
   */
  private String getVotes(List<KNNStackEntry<String, Double>> neighbors) {
    if (neighbors.isEmpty()) {
      throw new IllegalStateException(
          "No training series available to classify series #" + seriesCounter);
    }
    String[] labels = new String[neighbors.size()];
    int idx = 0;
    for (KNNStackEntry<String, Double> neighbor : neighbors) {
      labels[idx] = neighbor.getKey();
      idx++;
    }
    // The last slot is index NEIGHBOR_COUNT - 1 when the list is full. Reading it relative to the
    // actual length keeps that output while avoiding an ArrayIndexOutOfBoundsException when the
    // training set holds fewer than NEIGHBOR_COUNT series.
    String lastLabel = labels[labels.length - 1];
    return "ok_ " + Arrays.toString(labels) + " : " + seriesCounter + "," + getVote(labels) + ","
        + lastLabel;
  }

  /**
   * Counts one vote per label and returns the class index with the most votes. On a tie the
   * lowest class index wins, because a later class only replaces the leader with a strictly
   * greater count.
   *
   * @param res the neighbour labels, each starting with {@code "<class index>_"}.
   * @return the winning class index.
   */
  private int getVote(String[] res) {
    int[] votes = new int[CLASS_COUNT];
    for (String label : res) {
      votes[Integer.parseInt(label.substring(0, label.indexOf("_")))]++;
    }
    int maxVotes = -1;
    int maxIdx = 0;
    for (int i = 0; i < CLASS_COUNT; i++) {
      if (votes[i] > maxVotes) {
        maxVotes = votes[i];
        maxIdx = i;
      }
    }
    return maxIdx;
  }

  /**
   * Scans the training set and keeps at most {@value #NEIGHBOR_COUNT} entries with the smallest
   * distances to the given series.
   *
   * <p>
   * The returned list is not guaranteed to be sorted: {@link #checkDist} sorts it and then
   * appends an accepted candidate at the end.
   * </p>
   *
   * @param series the series being classified.
   * @param trainData the training set, mapping a label to its series.
   * @return the retained neighbours.
   */
  private List<KNNStackEntry<String, Double>> getNeighbors(double[] series,
      Map<String, double[]> trainData) {
    List<KNNStackEntry<String, Double>> res = new ArrayList<KNNStackEntry<String, Double>>();
    for (Entry<String, double[]> e : trainData.entrySet()) {
      double dist = getDist(series, e.getValue());
      if (res.size() < NEIGHBOR_COUNT) {
        res.add(new KNNStackEntry<String, Double>(e.getKey(), dist));
      }
      else {
        checkDist(res, e.getKey(), dist);
      }
    }
    return res;
  }

  /**
   * Offers a candidate to a full neighbour list. The list is sorted by ascending distance and,
   * if the candidate is strictly closer than the farthest entry, that entry is removed and the
   * candidate is appended.
   *
   * @param res the full neighbour list; modified in place.
   * @param label the candidate label.
   * @param dist the candidate distance.
   */
  private void checkDist(List<KNNStackEntry<String, Double>> res, String label, double dist) {
    Collections.sort(res, BY_DISTANCE);
    int farthest = NEIGHBOR_COUNT - 1;
    if (res.get(farthest).getValue() > dist) {
      res.remove(farthest);
      res.add(new KNNStackEntry<String, Double>(label, dist));
    }
  }

  /**
   * Computes the FastDTW distance between two series.
   *
   * @param series1 the first series.
   * @param series2 the second series.
   * @return the distance reported by the FastDTW warp info.
   */
  private double getDist(double[] series1, double[] series2) {
    final TimeSeries tsI = toTimeSeries(series1);
    final TimeSeries tsJ = toTimeSeries(series2);
    DistanceFunction distFn = DistanceFunctionFactory.getDistFnByName(DISTANCE_FUNCTION_NAME);
    final TimeWarpInfo info =
        com.dtw.FastDTW.getWarpInfoBetween(tsI, tsJ, SEARCH_RADIUS, distFn);
    return info.getDistance();
  }

  /**
   * Wraps a plain array into a one-dimensional {@link TimeSeries}, using the array index of each
   * value as its time stamp.
   *
   * @param values the values to wrap.
   * @return the time series holding one single-valued point per array element.
   */
  private static TimeSeries toTimeSeries(double[] values) {
    final TimeSeries ts = new TimeSeries(1);
    for (int i = 0; i < values.length; i++) {
      ts.addLast(i, new TimeSeriesPoint(new double[] { values[i] }));
    }
    return ts;
  }
}