package weka.core.shapelet;
import java.util.ArrayList;
import java.util.TreeMap;
import weka.core.Instances;
import weka.core.shapelet.QualityMeasures.ShapeletQualityMeasure;
/**
 * A {@link Shapelet} that additionally stores a split threshold on the
 * shapelet-to-series distance. The threshold is chosen by
 * {@link #calcInfoGainAndThreshold(ArrayList, TreeMap)} as the cut that gives the
 * highest information gain over a list of (distance, class value) entries.
 *
 * Copyright: Anthony Bagnall
 * @author u0318701
 */
public class BinaryShapelet extends Shapelet{

    /** Value held by {@link #splitThreshold} until a threshold has been computed. */
    private static final double NO_THRESHOLD = -1;

    /** Distance at which the order line is cut in two; -1 until a split has been found. */
    protected double splitThreshold = NO_THRESHOLD;

    /**
     * Creates a shapelet from its raw content; the split threshold starts at -1.
     *
     * @param content the shapelet values, passed on to the superclass
     */
    public BinaryShapelet(double[] content) {
        super(content);
    }

    /**
     * Creates a shapelet with its origin and quality measure; the split threshold
     * starts at -1.
     *
     * @param content       the shapelet values
     * @param seriesId      passed on to the superclass
     * @param startPos      passed on to the superclass
     * @param qualityChoice passed on to the superclass
     */
    public BinaryShapelet(double[] content, int seriesId, int startPos, ShapeletQualityMeasure qualityChoice) {
        super(content, seriesId, startPos, qualityChoice);
    }

    /**
     * Same as {@link #BinaryShapelet(double[], int, int, ShapeletQualityMeasure)} but
     * also stores an already known quality value.
     *
     * @param content       the shapelet values
     * @param seriesId      passed on to the superclass
     * @param startPos      passed on to the superclass
     * @param qualityChoice passed on to the superclass
     * @param qualityValue  value assigned to the inherited {@code qualityValue} field
     */
    public BinaryShapelet(double[] content, int seriesId, int startPos,
            ShapeletQualityMeasure qualityChoice, double qualityValue)
    {
        super(content, seriesId, startPos, qualityChoice);
        this.qualityValue = qualityValue;
    }

    /**
     * @return the stored split threshold, or -1 if none has been computed yet
     */
    public double getSplitThreshold() {
        return splitThreshold;
    }

    /**
     * Tries every boundary between neighbouring order line entries as a split
     * point, computes the information gain of each split and stores the threshold
     * of the best one in {@link #splitThreshold}.
     *
     * <p>A boundary between entries {@code i-1} and {@code i} is evaluated when
     * their distances differ; the very first boundary ({@code i == 1}) is always
     * evaluated. The threshold of a boundary is the midpoint of the two
     * neighbouring distances. The stored threshold is only replaced when the best
     * gain found is non-negative.
     *
     * <p>An order line with fewer than two entries has no boundary to evaluate, so
     * the call returns without changing anything.
     *
     * <p>NOTE(review): the code only compares neighbouring distances, so it
     * presumably expects {@code orderline} to be sorted by ascending distance -
     * confirm against the callers. Also note that when the first two entries share
     * the same distance the first boundary is still scored and its "midpoint"
     * equals that shared distance.
     *
     * @param orderline         entries providing a distance and a class value
     * @param classDistribution number of instances per class value for the whole
     *                          data set; used for the parent entropy and as the set
     *                          of class values to count
     * @throws NullPointerException if an order line entry has a class value that is
     *                              not a key of {@code classDistribution}
     */
    public void calcInfoGainAndThreshold(ArrayList<OrderLineObj> orderline, TreeMap<Double, Integer> classDistribution){
        final int size = orderline.size();
        if(size < 2){
            // no pair of neighbours, hence no candidate split point
            return;
        }

        // Running class counts on both sides of the candidate split. Every class of the
        // parent distribution is present (possibly with count 0) on both sides.
        TreeMap<Double, Integer> lessClasses = new TreeMap<Double, Integer>();
        TreeMap<Double, Integer> greaterClasses = new TreeMap<Double, Integer>();
        for(Double classVal : classDistribution.keySet()){
            lessClasses.put(classVal, 0);
            greaterClasses.put(classVal, 0);
        }
        // start with everything above the threshold ...
        for(int i = 0; i < size; i++){
            double classVal = orderline.get(i).getClassVal();
            addToCount(greaterClasses, classVal, 1);
        }
        // ... then move the first entry below it: the first boundary sits between 0 and 1
        double firstClassVal = orderline.get(0).getClassVal();
        addToCount(greaterClasses, firstClassVal, -1);
        addToCount(lessClasses, firstClassVal, 1);

        // the parent entropy does not depend on the split point
        final double parentEntropy = entropy(classDistribution);

        double lastDist = orderline.get(0).getDistance();
        double bsfGain = -1;
        double threshold = -1;

        for(int i = 1; i < size; i++){
            OrderLineObj current = orderline.get(i);
            double thisDist = current.getDistance();

            // no point in sampling identical thresholds; the first boundary is always sampled
            if(i == 1 || thisDist != lastDist){
                // exactly i entries lie below the boundary and size - i above it
                double lessFrac = (double) i / size;
                double greaterFrac = (double) (size - i) / size;
                double entropyLess = entropy(lessClasses);
                double entropyGreater = entropy(greaterClasses);

                double gain = parentEntropy - lessFrac * entropyLess - greaterFrac * entropyGreater;
                if(gain > bsfGain){
                    bsfGain = gain;
                    threshold = (thisDist - lastDist) / 2 + lastDist;
                }
            }

            // entry i lies below the next boundary
            double classVal = current.getClassVal();
            addToCount(greaterClasses, classVal, -1);
            addToCount(lessClasses, classVal, 1);
            lastDist = thisDist;
        }

        if(bsfGain >= 0){
            this.splitThreshold = threshold;
        }
    }

    /**
     * Adds {@code delta} to the count stored for {@code classVal}.
     *
     * @throws NullPointerException if {@code classVal} is not a key of {@code counts}
     */
    private static void addToCount(TreeMap<Double, Integer> counts, double classVal, int delta){
        int storedTotal = counts.get(classVal); // unboxing fails for an unknown class value
        counts.put(classVal, storedTotal + delta);
    }

    /**
     * Computes the entropy (base 2) of a class distribution.
     *
     * @param classDistributions number of instances per class value
     * @return the entropy; 0 when the map holds exactly one class, when it is
     *         empty, or when all counts are 0
     */
    private static double entropy(TreeMap<Double, Integer> classDistributions){
        if(classDistributions.size() == 1){
            return 0;
        }

        int total = 0;
        for(int count : classDistributions.values()){
            total += count;
        }

        double entropy = 0;
        for(int count : classDistributions.values()){
            double fraction = (double) count / total;
            double term = -fraction * Math.log10(fraction) / Math.log10(2);
            // a class with 0 instances (or an empty distribution) yields NaN; it contributes nothing
            if(Double.isNaN(term)){
                term = 0;
            }
            entropy += term;
        }
        return entropy;
    }
}