package edu.hawaii.jmotif.performance.digits;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import cc.mallet.util.Randoms;
import edu.hawaii.jmotif.performance.KNNStackEntry;
import edu.hawaii.jmotif.performance.UCRGenericClassifier;
import edu.hawaii.jmotif.performance.UCRUtils;
import edu.hawaii.jmotif.text.SAXCollectionStrategy;
import edu.hawaii.jmotif.text.TextUtils;
import edu.hawaii.jmotif.text.WordBag;

/**
 * Helper-runner for test.
 *
 * <p>
 * Reads the training and test collections, builds word bags for the classes "0", "1" and a merged
 * "other" bag, computes TF*IDF vectors for those bags and prints one CSV line per test series with
 * the outcome of comparing the series against every vector.
 *
 * @author psenin
 *
 */
public class VectorExplorer extends UCRGenericClassifier {

  // data locations
  //
  private static final String TRAINING_DATA = "data/digits/digits_reduced_400.csv";
  private static final String TEST_DATA = "data/digits/test.csv";

  // SAX parameters to try: { window size, PAA size, alphabet size, strategy code }
  //
  private static final int[][] params = { { 177, 16, 3, NOREDUCTION }, };

  // not referenced by any code in this class
  private static final Randoms randoms = new Randoms();

  /** Orders entries by their value, smallest first. */
  private static final Comparator<KNNStackEntry<String, Double>> BY_VALUE_ASCENDING =
      new Comparator<KNNStackEntry<String, Double>>() {
        @Override
        public int compare(KNNStackEntry<String, Double> arg0, KNNStackEntry<String, Double> arg1) {
          return arg0.getValue().compareTo(arg1.getValue());
        }
      };

  /**
   * Runnable.
   *
   * <p>
   * For every parameter row in {@code params} and every test series, prints a line of the form
   * {@code flag,actualClass,topClass,topValue,secondClass,secondValue,last}, where {@code flag} is
   * {@code 1} when the top-ranked class key equals the actual class (ignoring case) and {@code 0}
   * otherwise, and {@code last} is {@code NA} for a match or the value computed for the actual
   * class for a mismatch.
   *
   * @param args command-line arguments, not used.
   * @throws Exception if error occurs.
   */
  public static void main(String[] args) throws Exception {

    // making training and test collections
    //
    Map<String, List<double[]>> trainData = UCRUtils.readUCRData(TRAINING_DATA);
    Map<String, List<double[]>> testData = UCRUtils.readUCRData(TEST_DATA);

    // iterate over parameters
    //
    for (int[] p : params) {

      // converting back from easy encoding
      int windowSize = p[0];
      int paaSize = p[1];
      int alphabetSize = p[2];
      SAXCollectionStrategy strategy = toStrategy(p[3]);

      // making training bags collection
      List<WordBag> allBags = TextUtils.labeledSeries2WordBags(trainData, paaSize, alphabetSize,
          windowSize, strategy);
      List<WordBag> bags = groupTrainingBags(allBags);

      // getting TFIDF done
      HashMap<String, HashMap<String, Double>> tfidf = TextUtils.computeTFIDF(bags);

      // NOTE(review): test labels are compared with the tfidf keys as-is; a test series labeled
      // with anything other than "0" or "1" presumably never matches - confirm this is intended.
      for (Entry<String, List<double[]>> te : testData.entrySet()) {
        String classId = te.getKey();
        for (double[] series : te.getValue()) {
          // use the current parameter row, so that test bags are built with the same settings
          // as the training bags of this iteration
          WordBag test = TextUtils.seriesToWordBag("test", series, p);
          ArrayList<KNNStackEntry<String, Double>> cosines = rankClasses(test, tfidf);
          // report our findings
          System.out.println(formatResult(classId, cosines));
        }
      }
    }
  }

  /**
   * Translates the integer strategy code of a parameter row into a strategy constant.
   *
   * @param code the strategy code.
   * @return EXACT or NOREDUCTION when the code equals the corresponding constant, CLASSIC
   * otherwise.
   */
  private static SAXCollectionStrategy toStrategy(int code) {
    if (EXACT == code) {
      return SAXCollectionStrategy.EXACT;
    }
    if (NOREDUCTION == code) {
      return SAXCollectionStrategy.NOREDUCTION;
    }
    return SAXCollectionStrategy.CLASSIC;
  }

  /**
   * Keeps the bags labeled "0" and "1" as they are and merges all remaining bags into a single bag
   * labeled "other", which is appended last.
   *
   * @param allBags the bags to regroup.
   * @return the regrouped bags.
   */
  private static List<WordBag> groupTrainingBags(List<WordBag> allBags) {
    List<WordBag> bags = new ArrayList<WordBag>();
    WordBag others = new WordBag("other");
    for (WordBag bag : allBags) {
      String label = bag.getLabel();
      if ("0".equalsIgnoreCase(label) || "1".equalsIgnoreCase(label)) {
        bags.add(bag);
      }
      else {
        others.mergeWith(bag);
      }
    }
    bags.add(others);
    return bags;
  }

  /**
   * Computes {@link TextUtils#cosineSimilarity} of the test bag against every TF*IDF vector and
   * returns the results sorted by value in ascending order, i.e. the largest value is last.
   *
   * @param test the word bag of the series under test.
   * @param tfidf the TF*IDF vectors, keyed as returned by {@link TextUtils#computeTFIDF}.
   * @return the sorted list of (key, value) entries.
   */
  private static ArrayList<KNNStackEntry<String, Double>> rankClasses(WordBag test,
      HashMap<String, HashMap<String, Double>> tfidf) {
    ArrayList<KNNStackEntry<String, Double>> cosines =
        new ArrayList<KNNStackEntry<String, Double>>();
    for (Entry<String, HashMap<String, Double>> e : tfidf.entrySet()) {
      double dist = TextUtils.cosineSimilarity(test, e.getValue());
      cosines.add(new KNNStackEntry<String, Double>(e.getKey(), dist));
    }
    Collections.sort(cosines, BY_VALUE_ASCENDING);
    return cosines;
  }

  /**
   * Builds the CSV report line for a single test series.
   *
   * @param classId the actual class of the series.
   * @param cosines the entries sorted by value in ascending order.
   * @return the CSV line, all fields separated by commas.
   * @throws IllegalStateException if fewer than two entries are available.
   */
  private static String formatResult(String classId,
      ArrayList<KNNStackEntry<String, Double>> cosines) {
    int size = cosines.size();
    if (size < 2) {
      throw new IllegalStateException(
          "At least two classes are needed to report a result, but got " + size);
    }
    // the list is sorted ascending: the last entry holds the largest value
    KNNStackEntry<String, Double> best = cosines.get(size - 1);
    KNNStackEntry<String, Double> second = cosines.get(size - 2);
    boolean correct = classId.equalsIgnoreCase(best.getKey());

    StringBuilder line = new StringBuilder();
    line.append(correct ? "1" : "0");
    line.append(',').append(classId);
    line.append(',').append(best.getKey());
    line.append(',').append(best.getValue());
    line.append(',').append(second.getKey());
    line.append(',').append(second.getValue());
    line.append(',');
    if (correct) {
      line.append("NA");
    }
    else {
      line.append(getDistanceTo(classId, cosines));
    }
    return line.toString();
  }

  /**
   * Looks up the value stored for the given class.
   *
   * @param classId the class key to look for, compared ignoring case.
   * @param cosines the entries to search.
   * @return the value of the first entry whose key matches, or -1 if there is none.
   */
  private static double getDistanceTo(String classId,
      ArrayList<KNNStackEntry<String, Double>> cosines) {
    for (KNNStackEntry<String, Double> e : cosines) {
      if (classId.equalsIgnoreCase(e.getKey())) {
        return e.getValue();
      }
    }
    return -1.;
  }

  /**
   * Computes the cosine between two word-weight vectors defined over the same set of words. Not
   * called by any active code in this class.
   *
   * @param data1 the first vector, word to weight.
   * @param data2 the second vector, word to weight.
   * @return the dot product of the vectors divided by the product of their magnitudes.
   * @throws RuntimeException if the two key sets are not equal.
   */
  private static double cosine(HashMap<String, Double> data1, HashMap<String, Double> data2) {
    // sanity word order check
    if (!(data2.keySet().containsAll(data1.keySet()))
        || !(data2.keySet().size() == data1.keySet().size())) {
      throw new RuntimeException("COSINE SIMILARITY ERROR: word sets are different in length!");
    }
    double[] vector1 = new double[data1.size()];
    double[] vector2 = new double[data2.size()];
    // both arrays are filled in the iteration order of data1, so positions line up
    int i = 0;
    for (String s : data1.keySet()) {
      vector1[i] = data1.get(s);
      vector2[i] = data2.get(s);
      i++;
    }
    double numerator = TextUtils.dotProduct(vector1, vector2);
    double denominator = TextUtils.magnitude(vector1) * TextUtils.magnitude(vector2);
    return numerator / denominator;
  }

  /**
   * Reads series from a text file, one series per non-blank line, values separated by commas or
   * whitespace.
   *
   * @param fileName the file to read.
   * @return the series in file order.
   * @throws NumberFormatException if a token cannot be parsed as a double.
   * @throws IOException if the file cannot be read.
   */
  private static List<double[]> readTestData(String fileName) throws NumberFormatException,
      IOException {
    List<double[]> res = new ArrayList<double[]>();
    BufferedReader br = new BufferedReader(new FileReader(new File(fileName)));
    try {
      String line;
      while ((line = br.readLine()) != null) {
        String trimmed = line.trim();
        if (trimmed.length() == 0) {
          continue;
        }
        String[] split = trimmed.split(",|\\s+");
        double[] series = new double[split.length];
        for (int i = 0; i < split.length; i++) {
          series[i] = Double.parseDouble(split[i].trim());
        }
        res.add(series);
      }
    }
    finally {
      // close the reader even when reading or parsing fails
      br.close();
    }
    return res;
  }

  /**
   * Parses a double leniently.
   *
   * @param string the text to parse.
   * @return the parsed value, or NaN if the text is not a valid number.
   */
  private static Double parseValue(String string) {
    Double res = Double.NaN;
    try {
      res = Double.valueOf(string);
    }
    catch (NumberFormatException e) {
      // deliberately ignored: unparseable input is reported as NaN
    }
    return res;
  }

}