package edu.hawaii.jmotif.performance.cbf;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import cc.mallet.util.Randoms;
import edu.hawaii.jmotif.experiment.cbf.CBFGenerator;
import edu.hawaii.jmotif.performance.UCRGenericClassifier;
import edu.hawaii.jmotif.text.SAXCollectionStrategy;
import edu.hawaii.jmotif.text.TextUtils;
import edu.hawaii.jmotif.text.WordBag;

/**
 * Helper-runner for the CBF test on damaged data.
 *
 * Generates a clean training sample and a test sample, overwrites a random contiguous interval
 * of every test series with Gaussian noise, and then, for every combination of SAX parameters
 * (window, PAA and alphabet size), builds TF*IDF vectors from the training sample, classifies
 * the damaged test series and prints the resulting accuracy and error.
 *
 * @author psenin
 *
 */
public class UCRcbfKNNDamaged extends UCRGenericClassifier {

  // num of threads to use
  //
  private static final int THREADS_NUM = 4;

  // data; only referenced by the disabled UCRUtils.readUCRData(...) alternative in main
  //
  private static final String TRAINING_DATA = "data/CBF/CBF_TRAIN";
  private static final String TEST_DATA = "data/CBF/CBF_TEST";

  // output prefix
  //
  private static final String outputPrefix = "cbf_loocv_generated_1";

  // SAX parameters to use
  //
  private static final int WINDOW_MIN = 20;
  private static final int WINDOW_MAX = 70;
  private static final int WINDOW_INCREMENT = 1;

  private static final int PAA_MIN = 4;
  private static final int PAA_MAX = 8;
  private static final int PAA_INCREMENT = 1;

  private static final int ALPHABET_MIN = 3;
  private static final int ALPHABET_MAX = 10;
  private static final int ALPHABET_INCREMENT = 1;

  // leave out parameters; this runner only reports the value in its log output
  //
  private static final int LEAVE_OUT_NUM = 1;

  // the length of each generated series
  private static final int SERIES_LENGTH = 128;

  // sizes (series per class) of the generated training and test samples
  private static final int TRAIN_SAMPLE_SIZE = 60;
  private static final int TEST_SAMPLE_SIZE = 200;

  // the fraction of each test series replaced by noise, and the noise parameter
  private static final double DAMAGED_INTERVAL_LENGTH = 0.4;
  private static final double NOISE_STANDARD_DEVIATION = 0.167;

  // the random source used for damaging series; created once so that damage() never sees null
  private static final Randoms randoms = new Randoms();

  private UCRcbfKNNDamaged() {
    super();
  }

  /**
   * Runs the parameter sweep and prints one line per parameter combination in the form
   * <code>window,paa,alphabet,accuracy,error</code>.
   *
   * @param args optional; if the first argument is "EXACT" or "CLASSIC" (case-insensitive) the
   * corresponding {@link SAXCollectionStrategy} is used for classification, otherwise
   * NOREDUCTION is used.
   * @throws Exception if any of the underlying routines fails.
   */
  public static void main(String[] args) throws Exception {

    // configuring strategy
    //
    SAXCollectionStrategy strategy = SAXCollectionStrategy.NOREDUCTION;
    String strategyPrefix = "noreduction";
    if (args.length > 0) {
      String strategyP = args[0];
      if ("EXACT".equalsIgnoreCase(strategyP)) {
        strategy = SAXCollectionStrategy.EXACT;
        strategyPrefix = "exact";
      }
      else if ("CLASSIC".equalsIgnoreCase(strategyP)) {
        strategy = SAXCollectionStrategy.CLASSIC;
        strategyPrefix = "classic";
      }
    }
    consoleLogger.fine("strategy: " + strategyPrefix + ", leaving out: " + LEAVE_OUT_NUM);

    // make up window, paa and alphabet sizes
    int[] window_sizes = makeArray(WINDOW_MIN, WINDOW_MAX, WINDOW_INCREMENT);
    int[] paa_sizes = makeArray(PAA_MIN, PAA_MAX, PAA_INCREMENT);
    int[] alphabet_sizes = makeArray(ALPHABET_MIN, ALPHABET_MAX, ALPHABET_INCREMENT);

    // the training collection is generated rather than read from TRAINING_DATA
    //
    Map<String, List<double[]>> trainData = generateSample(TRAIN_SAMPLE_SIZE);
    consoleLogger.fine("trainData classes: " + trainData.size() + ", series length: "
        + trainData.entrySet().iterator().next().getValue().get(0).length);
    for (Entry<String, List<double[]>> e : trainData.entrySet()) {
      consoleLogger.fine(" training class: " + e.getKey() + " series: " + e.getValue().size());
    }

    // the test collection is generated rather than read from TEST_DATA, and then damaged
    //
    Map<String, List<double[]>> testData = damage(generateSample(TEST_SAMPLE_SIZE),
        DAMAGED_INTERVAL_LENGTH, NOISE_STANDARD_DEVIATION);
    consoleLogger.fine("testData classes: " + testData.size());
    for (Entry<String, List<double[]>> e : testData.entrySet()) {
      consoleLogger.fine(" test class: " + e.getKey() + " series: " + e.getValue().size());
    }

    // here is a loop over SAX parameters, strategy is fixed
    //
    for (int windowSize : window_sizes) {
      for (int paaSize : paa_sizes) {
        for (int alphabetSize : alphabet_sizes) {

          // skip the combination if the window cannot hold the PAA
          if (windowSize < paaSize + 1) {
            continue;
          }

          // making training bags collection
          // NOTE(review): the training bags are always built with NOREDUCTION while the
          // classification below uses the configured strategy - confirm this is intended.
          List<WordBag> bags = TextUtils.labeledSeries2WordBags(trainData, paaSize,
              alphabetSize, windowSize, SAXCollectionStrategy.NOREDUCTION);

          // getting TFIDF done and normalizing vectors
          HashMap<String, HashMap<String, Double>> tfidf = TextUtils.computeTFIDF(bags);
          tfidf = TextUtils.normalizeToUnitVectors(tfidf);

          // classifying; the value returned by classify() is accumulated as the number of
          // positives (presumably 1 for a correctly classified series - verify in TextUtils)
          int testSampleSize = 0;
          int positiveTestCounter = 0;
          for (String label : tfidf.keySet()) {
            List<double[]> testD = testData.get(label);
            for (double[] series : testD) {
              positiveTestCounter = positiveTestCounter
                  + TextUtils.classify(label, series, tfidf, paaSize, alphabetSize, windowSize,
                      strategy);
              testSampleSize++;
            }
          }

          // accuracy and error
          double accuracy = (double) positiveTestCounter / (double) testSampleSize;
          double error = 1.0d - accuracy;

          // report results
          System.out.println(windowSize + COMMA + paaSize + COMMA + alphabetSize + COMMA
              + accuracy + COMMA + error);
        }
      }
    }
  }

  /**
   * Generates a labeled sample with three classes: "1" built with
   * {@link CBFGenerator#cylinder}, "2" with {@link CBFGenerator#bell} and "3" with
   * {@link CBFGenerator#funnel}. Each series is generated over the ticks
   * 0..SERIES_LENGTH-1.
   *
   * @param sampleSize the number of series to generate for each class.
   * @return the map from a class label to the list of generated series.
   */
  private static Map<String, List<double[]>> generateSample(int sampleSize) {

    Map<String, List<double[]>> res = new HashMap<String, List<double[]>>();

    // ticks
    int[] t = new int[SERIES_LENGTH];
    for (int i = 0; i < SERIES_LENGTH; i++) {
      t[i] = i;
    }

    // cylinder sample
    List<double[]> cylinders = new ArrayList<double[]>(sampleSize);
    for (int i = 0; i < sampleSize; i++) {
      cylinders.add(CBFGenerator.cylinder(t));
    }
    res.put("1", cylinders);

    // bell sample
    List<double[]> bells = new ArrayList<double[]>(sampleSize);
    for (int i = 0; i < sampleSize; i++) {
      bells.add(CBFGenerator.bell(t));
    }
    res.put("2", bells);

    // funnel sample
    List<double[]> funnels = new ArrayList<double[]>(sampleSize);
    for (int i = 0; i < sampleSize; i++) {
      funnels.add(CBFGenerator.funnel(t));
    }
    res.put("3", funnels);

    return res;
  }

  /**
   * Produces a damaged copy of the labeled collection: in each series a contiguous interval
   * starting at a random position is overwritten with zero-mean Gaussian noise. The input
   * series are not modified; every damaged series is a fresh array.
   *
   * @param trainData the labeled series to damage (in this runner it is the test sample).
   * @param damagedIntervalLength the length of the damaged interval as a fraction of the series
   * length, must be within [0, 1].
   * @param noiseStandardDeviation the noise parameter, handed over as the second argument of
   * {@link Randoms#nextGaussian(double, double)}.
   * @return the map from a class label to the list of damaged series.
   * @throws IllegalArgumentException if the damagedIntervalLength is not within [0, 1].
   */
  private static Map<String, List<double[]>> damage(Map<String, List<double[]>> trainData,
      double damagedIntervalLength, double noiseStandardDeviation) {

    // written as a negated conjunction so that NaN is rejected too
    if (!(damagedIntervalLength >= 0.0D && damagedIntervalLength <= 1.0D)) {
      throw new IllegalArgumentException(
          "damagedIntervalLength must be within [0, 1], got " + damagedIntervalLength);
    }

    Map<String, List<double[]>> res = new HashMap<String, List<double[]>>();

    for (Entry<String, List<double[]>> referenceSet : trainData.entrySet()) {

      List<double[]> newData = new ArrayList<double[]>(referenceSet.getValue().size());

      for (double[] referenceSeries : referenceSet.getValue()) {

        // work on a copy so the caller's series stay intact
        double[] damagedSeries = referenceSeries.clone();
        int length = damagedSeries.length;

        // the start is drawn so that the whole interval fits into the series
        int noiseStart = (int) Math.floor(randoms.nextUniform(0D, length
            * (1 - damagedIntervalLength)));
        int noiseEnd = Math.min(length, noiseStart + (int) (length * damagedIntervalLength));

        // NOTE(review): MALLET documents the second argument of nextGaussian(m, s2) as the
        // variance, whereas a value named "standard deviation" is passed here - confirm which
        // one is intended before changing, since it affects the experiment results.
        for (int i = noiseStart; i < noiseEnd; i++) {
          damagedSeries[i] = randoms.nextGaussian(0, noiseStandardDeviation);
        }

        newData.add(damagedSeries);
      }

      res.put(referenceSet.getKey(), newData);
    }

    return res;
  }
}