/* * copyright: Anthony Bagnall * * */ package weka.classifiers.trees.shapelet_trees; import java.util.ArrayList; import weka.core.*; import weka.core.Instances; import java.util.TreeMap; import java.util.TreeSet; import java.util.ArrayList; import java.util.Calendar; import java.util.GregorianCalendar; import java.util.Collections; import java.io.FileReader; import java.io.FileWriter; import weka.classifiers.*; public class FStatShapeletTreeWithInfoGain extends AbstractClassifier{ private ShapeletNode root; private String logFileName; private int minLength, maxLength; public FStatShapeletTreeWithInfoGain(String logFileName) throws Exception { this.root = new ShapeletNode(); this.logFileName = logFileName; FileWriter fw = new FileWriter(logFileName); fw.close(); } public void setShapeletMinMaxLength(int minLength, int maxLength){ this.minLength = minLength; this.maxLength = maxLength; } @Override public void buildClassifier(Instances data) throws Exception { if(minLength < 1 || maxLength < 1){ throw new Exception("Shapelet minimum or maximum length is incorrectly specified!"); } root.initialiseNode(data, minLength, maxLength, 0); } @Override public double classifyInstance(Instance instance) { return root.classifyInstance(instance); } private Shapelet getRootShapelet() { return this.root.shapelet; } private class ShapeletNode { private ShapeletNode leftNode; private ShapeletNode rightNode; private double classDecision; private Shapelet shapelet; public ShapeletNode() { leftNode = null; rightNode = null; classDecision = -1; } public void initialiseNode(Instances data, int minShapeletLength, int maxShapeletLength, int level) throws Exception { FileWriter fw = new FileWriter(logFileName, true); fw.append("level:" + level + ", numInstances:" + data.numInstances() + "\n"); fw.close(); // 1. check whether this is a leaf node with only one class present double firstClassValue = data.instance(0).classValue(); boolean oneClass = true; for (int i = 1; i < data.numInstances(); i++) { if (data.instance(i).classValue() != firstClassValue) { oneClass = false; break; } } if (oneClass == true) { this.classDecision = firstClassValue; // no need to find shapelet, base case // System.out.println("base case"); fw = new FileWriter(logFileName, true); fw.append("class decision here: " + firstClassValue + "\n"); fw.close(); } else { // recursively call method to create left and right children nodes try { // 1. find the best shapelet to split the data this.shapelet = findBestShapelet(data, minShapeletLength, maxShapeletLength); // 2. split the data using the shapelet and create new data sets double dist; // System.out.println("Threshold:"+shapelet.getThreshold()); // System.out.println("length:"+shapelet.getLength()); ArrayList<Instance> splitLeft = new ArrayList<Instance>(); ArrayList<Instance> splitRight = new ArrayList<Instance>(); for (int i = 0; i < data.numInstances(); i++) { dist = subsequenceDistance(this.shapelet.content, data.instance(i).toDoubleArray()); // System.out.println("dist:"+dist); // if(dist< shapelet.medianDistance){ if (dist < shapelet.splitThresh) { splitLeft.add(data.instance(i)); // System.out.println("gone left"); } else { splitRight.add(data.instance(i)); // System.out.println("gone right"); } } // write to file here!!!! fw = new FileWriter(logFileName, true); fw.append("seriesId, startPos, length, infoGain, splitThresh\n"); fw.append(this.shapelet.seriesId + "," + this.shapelet.startPos + "," + this.shapelet.content.length + "," + this.shapelet.fStat + "," + this.shapelet.splitThresh + "\n"); for (int j = 0; j < this.shapelet.content.length; j++) { fw.append(this.shapelet.content[j] + ","); } fw.append("\n"); fw.close(); System.out.println("shapelet completed at:" + System.nanoTime()); // System.out.println("leftSize:"+splitLeft.size()); // System.out.println("leftRight:"+splitRight.size()); // 5. initialise and recursively compute children nodes leftNode = new ShapeletNode(); rightNode = new ShapeletNode(); // System.out.println("SplitLeft:"); //*** new condition! added for mixing stats because Electric data is a bit nuts- if no 'good' shapelet exists and it splits nothing, set class val to most common class val if (splitLeft.isEmpty() || splitRight.isEmpty()) { TreeMap<Double, Integer> classesForEscape = getClassDistributions(data); double bestKey = data.instance(0).classValue(); int bestTotal = 0; for (Double d : classesForEscape.keySet()) { if (classesForEscape.get(d) > bestTotal) { bestTotal = classesForEscape.get(d); bestKey = d; } } this.classDecision = bestKey; // no need to find shapelet, base case // System.out.println("base case"); fw = new FileWriter(logFileName, true); fw.append("PRUNED class decision here: " + bestKey + "\n"); fw.close(); } else { Instances leftInstances = new Instances(data, splitLeft.size()); for (int i = 0; i < splitLeft.size(); i++) { leftInstances.add(splitLeft.get(i)); } Instances rightInstances = new Instances(data, splitRight.size()); for (int i = 0; i < splitRight.size(); i++) { rightInstances.add(splitRight.get(i)); } fw = new FileWriter(logFileName, true); fw.append("left size under level " + level + ": " + leftInstances.numInstances() + "\n"); fw.close(); leftNode.initialiseNode(leftInstances, minShapeletLength, maxShapeletLength, (level + 1)); // System.out.println("SplitRight:"); fw = new FileWriter(logFileName, true); fw.append("right size under level " + level + ": " + rightInstances.numInstances() + "\n"); fw.close(); rightNode.initialiseNode(rightInstances, minShapeletLength, maxShapeletLength, (level + 1)); } } catch (Exception e) { System.out.println("Problem initialising tree node: " + e); e.printStackTrace(); } } } public double classifyInstance(Instance instance) { if (this.leftNode == null) { return this.classDecision; } else { double distance = subsequenceDistance(this.shapelet.content, instance); if (distance < this.shapelet.splitThresh) { return leftNode.classifyInstance(instance); } else { return rightNode.classifyInstance(instance); } } } } //# public double timingForSingleShapelet(Instances data, int minShapeletLength, int maxShapeletLength) { long startTime = System.nanoTime(); this.findBestShapelet(data, minShapeletLength, maxShapeletLength); long finishTime = System.nanoTime(); return (double)(finishTime - startTime) / 1000000000.0; } // edited from findBestKShapeletsCached private Shapelet findBestShapelet(Instances data, int minShapeletLength, int maxShapeletLength) { Shapelet bestShapelet = null; TreeMap<Double, Integer> classDistributions = getClassDistributions(data); // used to calc info gain //for all time series System.out.println("Processing data: "); for (int i = 0; i < data.numInstances(); i++) { double[] wholeCandidate = data.instance(i).toDoubleArray(); // for all lengths for (int length = minShapeletLength; length <= maxShapeletLength; length++) { //for all possible starting positions of that length for (int start = 0; start <= wholeCandidate.length - length - 1; start++) { //-1 = avoid classVal - handle later for series with no class val // CANDIDATE ESTABLISHED - got original series, length and starting position // extract relevant part into a double[] for processing double[] candidate = new double[length]; for (int m = start; m < start + length; m++) { candidate[m - start] = wholeCandidate[m]; } candidate = zNorm(candidate, false); Shapelet candidateShapelet = checkCandidate(candidate, data, i, start, classDistributions); if (bestShapelet == null || candidateShapelet.compareTo(bestShapelet) < 0) { bestShapelet = candidateShapelet; // System.out.println("here"); } } } } bestShapelet.calculateBestSplitPoint(classDistributions); //print out the k best shapes and then return // System.out.println("Shapelet No, Series ID, Start, Length, InfogGain, Gap,"); return bestShapelet; } /** * * @param shapelets the input Shapelets to remove self similar Shapelet objects from * @return a copy of the input ArrayList with self-similar shapelets removed */ // private static ArrayList<Shapelet> removeSelfSimilar(ArrayList<Shapelet> shapelets){ // // return a new pruned array list - more efficient than removing // // self-similar entries on the fly and constantly reindexing // ArrayList<Shapelet> outputShapelets = new ArrayList<Shapelet>(); // boolean[] selfSimilar = new boolean[shapelets.size()]; // // // to keep tract of self similarity - assume nothing is similar to begin with // for(int i = 0; i < shapelets.size(); i++){ // selfSimilar[i] = false; // } // // for(int i = 0; i < shapelets.size();i++){ // if(selfSimilar[i]==false){ // outputShapelets.add(shapelets.get(i)); // for(int j = i+1; j < shapelets.size(); j++){ // if(selfSimilar[j]==false && selfSimilarity(shapelets.get(i),shapelets.get(j))){ // no point recalc'ing if already self similar to something // selfSimilar[j] = true; // } // } // } // } // return outputShapelets; // } /** * * @param k the maximum number of shapelets to be returned after combining the two lists * @param kBestSoFar the (up to) k best shapelets that have been observed so far, passed in to combine with shapelets from a new series * @param timeSeriesShapelets the shapelets taken from a new series that are to be merged in descending order of fitness with the kBestSoFar * @return an ordered ArrayList of the best k (or less) Shapelet objects from the union of the input ArrayLists */ private ArrayList<Shapelet> combine(int k, ArrayList<Shapelet> kBestSoFar, ArrayList<Shapelet> timeSeriesShapelets) { ArrayList<Shapelet> newBestSoFar = new ArrayList<Shapelet>(); for (int i = 0; i < timeSeriesShapelets.size(); i++) { kBestSoFar.add(timeSeriesShapelets.get(i)); } Collections.sort(kBestSoFar); if (kBestSoFar.size() < k) { return kBestSoFar; // no need to return up to k, as there are not k shapelets yet } for (int i = 0; i < k; i++) { newBestSoFar.add(kBestSoFar.get(i)); } return newBestSoFar; } /** * * @param data the input data set that the class distributions are to be derived from * @return a TreeMap<Double, Integer> in the form of <Class Value, Frequency> */ private static TreeMap<Double, Integer> getClassDistributions(Instances data) { TreeMap<Double, Integer> classDistribution = new TreeMap<Double, Integer>(); double classValue; for (int i = 0; i < data.numInstances(); i++) { classValue = data.instance(i).classValue(); boolean classExists = false; for (Double d : classDistribution.keySet()) { if (d == classValue) { int temp = classDistribution.get(d); temp++; classDistribution.put(classValue, temp); classExists = true; } } if (classExists == false) { classDistribution.put(classValue, 1); } } return classDistribution; } /** * * @param candidate the data from the candidate Shapelet * @param data the entire data set to compare the candidate to * @param data the entire data set to compare the candidate to * @return a TreeMap<Double, Integer> in the form of <Class Value, Frequency> */ private static Shapelet checkCandidate(double[] candidate, Instances data, int seriesId, int startPos, TreeMap classDistribution) { // create orderline by looping through data set and calculating the subsequence // distance from candidate to all data, inserting in order. ArrayList<OrderLineObj> orderline = new ArrayList<OrderLineObj>(); for (int i = 0; i < data.numInstances(); i++) { double distance = subsequenceDistance(candidate, data.instance(i)); double classVal = data.instance(i).classValue(); // NOTE: MAKE SURE THE ORDERLINE IS SORTED IN THE QUALITY MEASURE orderline.add(new OrderLineObj(distance, classVal)); } Shapelet shapelet = new Shapelet(candidate, seriesId, startPos); // shapelet.calculateMoodsMedian(orderline, classDistribution); shapelet.calculateMoodsMedianTree(orderline, classDistribution); return shapelet; } /** * * @param candidate * @param timeSeriesIns * @return */ public static double subsequenceDistance(double[] candidate, Instance timeSeriesIns) { double[] timeSeries = timeSeriesIns.toDoubleArray(); return subsequenceDistance(candidate, timeSeries); } public static double subsequenceDistance(double[] candidate, double[] timeSeries) { // double[] timeSeries = timeSeriesIns.toDoubleArray(); double bestSum = Double.MAX_VALUE; double sum = 0; double[] subseq; // for all possible subsequences of two for (int i = 0; i <= timeSeries.length - candidate.length - 1; i++) { sum = 0; // get subsequence of two that is the same lenght as one subseq = new double[candidate.length]; for (int j = i; j < i + candidate.length; j++) { subseq[j - i] = timeSeries[j]; } subseq = zNorm(subseq, false); // Z-NORM HERE for (int j = 0; j < candidate.length; j++) { sum += (candidate[j] - subseq[j]) * (candidate[j] - subseq[j]); } if (sum < bestSum) { bestSum = sum; } } return (1.0 / candidate.length * bestSum); } /** * * @param input * @param classValOn * @return */ public static double[] zNorm(double[] input, boolean classValOn) { double mean; double stdv; double classValPenalty = 0; if (classValOn) { classValPenalty = 1; } double[] output = new double[input.length]; double seriesTotal = 0; for (int i = 0; i < input.length - classValPenalty; i++) { seriesTotal += input[i]; } mean = seriesTotal / (input.length - classValPenalty); stdv = 0; for (int i = 0; i < input.length - classValPenalty; i++) { stdv += (input[i] - mean) * (input[i] - mean); } stdv = stdv / input.length - classValPenalty; stdv = Math.sqrt(stdv); for (int i = 0; i < input.length - classValPenalty; i++) { output[i] = (input[i] - mean) / stdv; } if (classValOn == true) { output[output.length - 1] = input[input.length - 1]; } return output; } /** * * @param fileName * @return */ public static Instances loadData(String fileName) { Instances data = null; try { FileReader r; r = new FileReader(fileName); data = new Instances(r); data.setClassIndex(data.numAttributes() - 1); } catch (Exception e) { System.out.println(" Error =" + e + " in method loadData"); } return data; } // private static boolean selfSimilarity(Shapelet shapelet, Shapelet candidate){ // if(candidate.seriesId == shapelet.seriesId){ // if(candidate.startPos >= shapelet.startPos && candidate.startPos < shapelet.startPos + shapelet.content.length){ //candidate starts within exisiting shapelet // return true; // } // if(shapelet.startPos >= candidate.startPos && shapelet.startPos < candidate.startPos + candidate.content.length){ // return true; // } // } // return false; // } private static class Shapelet implements Comparable<Shapelet> { private double[] content; private int seriesId; private int startPos; private double fStat; private ArrayList<OrderLineObj> orderline; private double splitThresh; private double separationGap; private Shapelet(double[] content, int seriesId, int startPos) { this.content = content; this.seriesId = seriesId; this.startPos = startPos; } private Shapelet(double[] content) { this.content = content; } /** * A method to calculate the quality of a Shapelet, given the orderline produced by computing the distance * from the shapelet to each element of the dataset. * * @param orderline the pre-computed set of distances for a dataset to a single shapelet * @param classDistribution the distibution of all possible class values in the orderline * @return a measure of shapelet quality according to f-stat */ public void calculateMoodsMedian(ArrayList<OrderLineObj> orderline, TreeMap<Double, Integer> classDistribution) { this.orderline = orderline; Collections.sort(orderline); int numClasses = classDistribution.size(); int numInstances = orderline.size(); double[] sums = new double[numClasses]; double[] sumsSquared = new double[numClasses]; double[] sumOfSquares = new double[numClasses]; for (int i = 0; i < numClasses; i++) { sums[i] = 0; sumsSquared[i] = 0; sumOfSquares[i] = 0; } for (int i = 0; i < orderline.size(); i++) { int c = (int) orderline.get(i).classVal; double thisDist = orderline.get(i).distance; System.out.println("c = "+c+", numClasses = "+numClasses); sums[c] += thisDist; sumOfSquares[c] += thisDist * thisDist; } // error below here for (int i = 0; i < numClasses; i++) { sumsSquared[i] = sums[i] * sums[i]; } double ssTotal = 0; double part1 = 0; double part2 = 0; for (int i = 0; i < numClasses; i++) { part1 += sumOfSquares[i]; part2 += sums[i]; } part2 *= part2; part2 /= numInstances; ssTotal = part1 - part2; double ssAmoung = 0; part1 = 0; part2 = 0; for (int i = 0; i < numClasses; i++) { part1 += (double) sumsSquared[i] / classDistribution.get((double) i);//.data[i].size(); part2 += sums[i]; } ssAmoung = part1 - (part2 * part2) / numInstances; double ssWithin = ssTotal - ssAmoung; int dfAmoung = numClasses - 1; int dfWithin = numInstances - numClasses; double msAmoung = ssAmoung / dfAmoung; double msWithin = ssWithin / dfWithin; double f = msAmoung / msWithin; this.fStat = f; } public void calculateMoodsMedianTree(ArrayList<OrderLineObj> orderline, TreeMap<Double, Integer> classDistribution) { this.orderline = orderline; Collections.sort(orderline); int numClasses = classDistribution.size(); int numInstances = orderline.size(); double[] sums = new double[numClasses]; double[] sumsSquared = new double[numClasses]; double[] sumOfSquares = new double[numClasses]; // could be more efficient, but added in to adapt class distribution for the tree implementation // original implementation used numClasses as an index which is fine for filter, but not nec. for a tree // i.e. numClasses might = 2, but the class vals could be {0,3} if the tree has already split before double[] classValuesArray = new double[numClasses]; double[] classValuesArrayCounts = new double[numClasses]; int index = 0; for(Double d:classDistribution.keySet()){ classValuesArray[index] = d; classValuesArrayCounts[index] = classDistribution.get(d); index++; } for (int i = 0; i < numClasses; i++) { sums[i] = 0; sumsSquared[i] = 0; sumOfSquares[i] = 0; } for (int i = 0; i < orderline.size(); i++) { // c is where the problem happened, as this could be higher than numClasses if a tree has already branched with a multi-class problem int c = (int) orderline.get(i).classVal; double thisDist = orderline.get(i).distance; for(int j = 0; j < numClasses; j++){ if(classValuesArray[j]==c){ sums[j] += thisDist; sumOfSquares[j] += thisDist * thisDist; continue; // a bit hacky, but saves going around any remaining class vals } } // left in to demonstrate how it is different from the original filter implementation // sums[c] += thisDist; // sumOfSquares[c] += thisDist * thisDist; } for (int i = 0; i < numClasses; i++) { sumsSquared[i] = sums[i] * sums[i]; } double ssTotal = 0; double part1 = 0; double part2 = 0; for (int i = 0; i < numClasses; i++) { part1 += sumOfSquares[i]; part2 += sums[i]; } part2 *= part2; part2 /= numInstances; ssTotal = part1 - part2; double ssAmoung = 0; part1 = 0; part2 = 0; for (int i = 0; i < numClasses; i++) { // code also changed here, as the ref to classDistribution is incorrect for a tree // part1 += (double) sumsSquared[i] / classDistribution.get((double) i);//.data[i].size(); part1 += (double) sumsSquared[i] / classValuesArrayCounts[i];//.data[i].size(); part2 += sums[i]; } ssAmoung = part1 - (part2 * part2) / numInstances; double ssWithin = ssTotal - ssAmoung; int dfAmoung = numClasses - 1; int dfWithin = numInstances - numClasses; double msAmoung = ssAmoung / dfAmoung; double msWithin = ssWithin / dfWithin; double f = msAmoung / msWithin; this.fStat = f; } private void calculateBestSplitPoint(TreeMap<Double, Integer> classDistribution) { Collections.sort(orderline); double lastDist = orderline.get(0).distance; double thisDist = -1; double bsfGain = -1; double threshold = -1; for (int i = 1; i < orderline.size(); i++) { thisDist = orderline.get(i).distance; if (i == 1 || thisDist != lastDist) { // check that threshold has moved(no point in sampling identical thresholds)- special case - if 0 and 1 are the same dist // count class instances below and above threshold TreeMap<Double, Integer> lessClasses = new TreeMap<Double, Integer>(); TreeMap<Double, Integer> greaterClasses = new TreeMap<Double, Integer>(); for (double j : classDistribution.keySet()) { lessClasses.put(j, 0); greaterClasses.put(j, 0); } int sumOfLessClasses = 0; int sumOfGreaterClasses = 0; //visit those below threshold for (int j = 0; j < i; j++) { double thisClassVal = orderline.get(j).classVal; int storedTotal = lessClasses.get(thisClassVal); storedTotal++; lessClasses.put(thisClassVal, storedTotal); sumOfLessClasses++; } //visit those above threshold for (int j = i; j < orderline.size(); j++) { double thisClassVal = orderline.get(j).classVal; int storedTotal = greaterClasses.get(thisClassVal); storedTotal++; greaterClasses.put(thisClassVal, storedTotal); sumOfGreaterClasses++; } int sumOfAllClasses = sumOfLessClasses + sumOfGreaterClasses; double parentEntropy = entropy(classDistribution); // calculate the info gain below the threshold double lessFrac = (double) sumOfLessClasses / sumOfAllClasses; double entropyLess = entropy(lessClasses); // calculate the info gain above the threshold double greaterFrac = (double) sumOfGreaterClasses / sumOfAllClasses; double entropyGreater = entropy(greaterClasses); double gain = parentEntropy - lessFrac * entropyLess - greaterFrac * entropyGreater; if (gain > bsfGain) { bsfGain = gain; threshold = (thisDist - lastDist) / 2 + lastDist; } } lastDist = thisDist; } this.splitThresh = threshold; } private static double entropy(TreeMap<Double, Integer> classDistributions) { if (classDistributions.size() == 1) { return 0; } double thisPart; double toAdd; int total = 0; for (Double d : classDistributions.keySet()) { total += classDistributions.get(d); } // to avoid NaN calculations, the individual parts of the entropy are calculated and summed. // i.e. if there is 0 of a class, then that part would calculate as NaN, but this can be caught and // set to 0. ArrayList<Double> entropyParts = new ArrayList<Double>(); for (Double d : classDistributions.keySet()) { thisPart = (double) classDistributions.get(d) / total; toAdd = -thisPart * Math.log10(thisPart) / Math.log10(2); if (Double.isNaN(toAdd)) { toAdd = 0; } entropyParts.add(toAdd); } double entropy = 0; for (int i = 0; i < entropyParts.size(); i++) { entropy += entropyParts.get(i); } return entropy; } public int getLength() { return this.content.length; } public int compareTo(Shapelet shapelet) { final int BEFORE = -1; final int EQUAL = 0; final int AFTER = 1; if (this.fStat != shapelet.fStat) { if (this.fStat > shapelet.fStat) { return BEFORE; } else { return AFTER; } } else if (this.content.length != shapelet.getLength()) { if (this.content.length < shapelet.getLength()) { return BEFORE; } else { return AFTER; } } else { return EQUAL; } } } private static class OrderLineObj implements Comparable<OrderLineObj> { private double distance; private double classVal; private OrderLineObj(double distance, double classVal) { this.distance = distance; this.classVal = classVal; } public int compareTo(OrderLineObj o) { if (this.distance < o.distance) { return -1; } else if (this.distance == o.distance) { return 0; } else { return 1; } } } }