package edu.hawaii.jmotif.performance.digits; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import edu.hawaii.jmotif.performance.KNNStackEntry; import edu.hawaii.jmotif.performance.UCRGenericClassifier; import edu.hawaii.jmotif.performance.UCRUtils; import edu.hawaii.jmotif.text.SAXCollectionStrategy; import edu.hawaii.jmotif.text.TextUtils; import edu.hawaii.jmotif.text.WordBag; /** * Helper-runner for test. * * @author psenin * */ public class ConfusionMatrixGenerator extends UCRGenericClassifier { // data locations // private static final String TRAINING_DATA = "data/digits/digits_reduced_1000.csv"; private static final String TEST_DATA = "data/digits/digits_train.txt"; // private static final String TRAINING_DATA = "data/digits/digits_reduced_50.csv"; // private static final String TEST_DATA = "data/digits/digits_reduced_50.csv"; // SAX parameters to try // private static final int[][] params = { { 177, 16, 3, NOREDUCTION }, }; /** * Runnable. * * @throws Exception if error occurs. */ public static void main(String[] args) throws Exception { // making training and test collections // Map<String, List<double[]>> trainData = UCRUtils.readUCRData(TRAINING_DATA); Map<String, List<double[]>> testData = UCRUtils.readUCRData(TEST_DATA); HashMap<String, HashMap<String, Integer>> confusionMatrix = new HashMap<String, HashMap<String, Integer>>(); for (Integer i = 0; i < 10; i++) { HashMap<String, Integer> map = new HashMap<String, Integer>(); for (Integer k = 0; k < 10; k++) { map.put(String.valueOf(k), 0); } confusionMatrix.put(String.valueOf(i), map); } // iterate over parameters // for (int[] p : params) { // converting back from easy encoding int WINDOW_SIZE = p[0]; int PAA_SIZE = p[1]; int ALPHABET_SIZE = p[2]; SAXCollectionStrategy strategy = SAXCollectionStrategy.CLASSIC; if (EXACT == p[3]) { strategy = SAXCollectionStrategy.EXACT; } else if (NOREDUCTION == p[3]) { strategy = SAXCollectionStrategy.NOREDUCTION; } // making training bags collection List<WordBag> bags = TextUtils.labeledSeries2WordBags(trainData, PAA_SIZE, ALPHABET_SIZE, WINDOW_SIZE, strategy); // getting TFIDF done HashMap<String, HashMap<String, Double>> tfidf = TextUtils.computeTFIDF(bags); int seriesCounter = 1; for (Entry<String, List<double[]>> te : testData.entrySet()) { String classId = te.getKey(); for (double[] series : te.getValue()) { WordBag test = TextUtils.seriesToWordBag("test", series, params[0]); ArrayList<KNNStackEntry<String, Double>> cosines = new ArrayList<KNNStackEntry<String, Double>>(); // get cosines computed for (Entry<String, HashMap<String, Double>> e : tfidf.entrySet()) { double dist = TextUtils.cosineSimilarity(test, e.getValue()); cosines.add(new KNNStackEntry<String, Double>(e.getKey(), dist)); } Collections.sort(cosines, new Comparator<KNNStackEntry<String, Double>>() { @Override public int compare(KNNStackEntry<String, Double> arg0, KNNStackEntry<String, Double> arg1) { return arg0.getValue().compareTo(arg1.getValue()); } }); // report our findings if (classId.equalsIgnoreCase(cosines.get(9).getKey())) { System.out.println("correct: series of class " + classId + " classified as class " + cosines.get(9).getKey() + ", distance " + cosines.get(9).getValue() + ", second class: " + cosines.get(8).getKey() + ", distance: " + cosines.get(8).getValue()); } else { System.out.println("incorrect: series of class " + classId + " classified as class " + cosines.get(9).getKey() + ", distance " + cosines.get(9).getValue() + ", second class: " + cosines.get(8).getKey() + ", distance: " + cosines.get(8).getValue() + ", distance to correct class " + getDistanceTo(classId, cosines)); Integer oldCounter = confusionMatrix.get(classId).get(cosines.get(9).getKey()); confusionMatrix.get(classId).put(cosines.get(9).getKey(), oldCounter + 1); } seriesCounter++; } } } System.out.println("Confusion data :"); for (Integer i = 0; i < 10; i++) { HashMap<String, Integer> map = confusionMatrix.get(String.valueOf(i)); System.out.print(i + ","); for (Integer k = 0; k < 10; k++) { System.out.print(map.get(String.valueOf(k)) + ","); } System.out.println(); } } private static double getDistanceTo(String classId, ArrayList<KNNStackEntry<String, Double>> cosines) { for (KNNStackEntry<String, Double> e : cosines) { if (classId.equalsIgnoreCase(e.getKey())) { return e.getValue(); } } return -1.; } }