/*
 * A binary filter that uses the information gain quality measure to determine the split point.
 * Copyright: Anthony Bagnall
 */
package weka.filters.timeseries;
//import java.io.FileWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.TreeMap;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.shapelet.OrderLineObj;
import weka.core.shapelet.QualityMeasures.ShapeletQualityChoice;
import weka.core.shapelet.*;
import weka.filters.SimpleBatchFilter;
import weka.filters.timeseries.shapelet_transforms.FullShapeletTransform;
/**
*
* @author Jon Hills j.hills@uea.ac.uk
*/
/**
 * A binary discretisation filter: each numeric attribute is reduced to a single
 * nominal {"0","1"} attribute by thresholding at the split point that maximises
 * information gain with respect to the class.
 *
 * The split points are learned from the first batch passed to
 * {@link #process(Instances)} and re-used for subsequent batches until
 * {@link #findNewSplits()} is called; later batches are therefore assumed to
 * share the schema of the batch the splits were learned from.
 *
 * @author Jon Hills j.hills@uea.ac.uk
 */
public class BinaryTransform extends SimpleBatchFilter{

    /** When true, the next call to process() re-estimates the split points. */
    private boolean findNewSplits=true;
    /** splits[j] is the learned threshold for input attribute j (class slot unused). */
    private double[] splits;

    /** Forces the next call to process() to re-estimate the split points. */
    public void findNewSplits(){findNewSplits=true;}

    /**
     * Builds the output header: one binary nominal attribute ("0"/"1") per
     * non-class input attribute, named Binary_0 ... Binary_{m-1}, plus a copy
     * of the class attribute appended LAST when the input has a class set.
     *
     * @param inputFormat the input header; all non-class attributes must be numeric
     * @return the output header described above
     * @throws Exception if any non-class attribute is not numeric
     */
    @Override
    protected Instances determineOutputFormat(Instances inputFormat) throws Exception{
        // All non-class attributes must be real valued, otherwise throw exception.
        for(int i=0;i<inputFormat.numAttributes();i++)
            if(inputFormat.classIndex()!=i && !inputFormat.attribute(i).isNumeric())
                throw new Exception("Non numeric attribute not allowed in BinaryTransform");
        int length=inputFormat.numAttributes();
        if(inputFormat.classIndex()>=0)
            length--;
        // Set up instances size and format. A single shared "0"/"1" value list
        // backs every binary attribute.
        FastVector atts=new FastVector();
        FastVector attributeValues=new FastVector();
        attributeValues.addElement("0");
        attributeValues.addElement("1");
        for(int i=0;i<length;i++)
            atts.addElement(new Attribute("Binary_"+i,attributeValues));
        if(inputFormat.classIndex()>=0){ // Classification set: copy the class values
            Attribute target=inputFormat.attribute(inputFormat.classIndex());
            FastVector vals=new FastVector(target.numValues());
            for(int i=0;i<target.numValues();i++)
                vals.addElement(target.value(i));
            atts.addElement(new Attribute(target.name(),vals));
        }
        Instances result=new Instances("Binary"+inputFormat.relationName(),atts,inputFormat.numInstances());
        if(inputFormat.classIndex()>=0){
            result.setClassIndex(result.numAttributes()-1);
        }
        return result;
    }

    /**
     * Learns (on the first batch only) an information-gain-optimal threshold
     * per attribute, then thresholds every instance: values below the split
     * map to 0, values at or above it map to 1.
     *
     * Note: learning the splits reads the class value of every instance, so
     * the (first) batch must have a class index set.
     *
     * @param data the batch to transform
     * @return the binarised batch in the format of determineOutputFormat()
     * @throws Exception if a non-class attribute is not numeric
     */
    @Override
    public Instances process(Instances data) throws Exception{
        Instances output=determineOutputFormat(data);
        if(findNewSplits){
            splits=new double[data.numAttributes()];
            double[] classes=new double[data.numInstances()];
            for(int i=0;i<classes.length;i++)
                classes[i]=data.instance(i).classValue();
            for(int j=0;j<data.numAttributes();j++){ // for each data attribute
                if(j!=data.classIndex()){
                    // Extract the values of attribute j and find its IG split point.
                    double[] vals=new double[data.numInstances()];
                    for(int i=0;i<data.numInstances();i++)
                        vals[i]=data.instance(i).value(j);
                    splits[j]=findSplitValue(data,vals,classes);
                }
            }
            findNewSplits=false;
        }
        // Threshold every instance. Non-class attributes are written to
        // consecutive output positions and the class (if any) to the output's
        // own class index, which is always LAST in the output format. (The
        // previous code wrote the class at the input class index and so
        // mis-filled instances whenever the input class was not last.)
        for(int i=0;i<data.numInstances();i++){
            Instance newInst=new DenseInstance(output.numAttributes());
            int k=0;
            for(int j=0;j<data.numAttributes();j++){
                if(j!=data.classIndex())
                    newInst.setValue(k++,data.instance(i).value(j)<splits[j]?0:1);
            }
            if(data.classIndex()>=0)
                newInst.setValue(output.classIndex(),data.instance(i).classValue());
            output.add(newInst);
        }
        return output;
    }

    /**
     * Finds the information-gain-optimal split value for one attribute.
     *
     * @param data    the full data set (used only to obtain the class distribution)
     * @param vals    the attribute value of each instance
     * @param classes the class value of each instance, aligned with vals
     * @return the midpoint threshold with the best information gain, or -1 if
     *         no split point exists (e.g. all values identical)
     */
    public double findSplitValue(Instances data, double[] vals, double[] classes){
        // Pair each value with its class and sort into an orderline.
        ArrayList<OrderLineObj> list=new ArrayList<OrderLineObj>();
        for(int i=0;i<vals.length;i++)
            list.add(new OrderLineObj(vals[i],classes[i]));
        TreeMap<Double,Integer> tree=FullShapeletTransform.getClassDistributions(data);
        Collections.sort(list);
        return infoGainThreshold(list,tree);
    }

    /**
     * Shannon entropy (base 2) of a class distribution. Zero-count classes
     * contribute nothing: the 0*log(0) NaN is explicitly mapped to 0.
     *
     * @param classDistributions class value -> count
     * @return the entropy in bits
     */
    private static double entropy(TreeMap<Double,Integer> classDistributions){
        if(classDistributions.size()==1){
            return 0;
        }
        int total=0;
        for(Integer count:classDistributions.values()){
            total+=count;
        }
        double entropy=0;
        for(Integer count:classDistributions.values()){
            double thisPart=(double)count/total;
            double toAdd=-thisPart*Math.log10(thisPart)/Math.log10(2);
            // A class with count 0 yields NaN (0 * -inf); treat it as 0.
            if(!Double.isNaN(toAdd)){
                entropy+=toAdd;
            }
        }
        return entropy;
    }

    /**
     * Scans every candidate split point of a sorted orderline and returns the
     * midpoint threshold with the highest information gain. Identical adjacent
     * distances are sampled only once (no point in re-testing the same
     * threshold), except the very first position, which is always tested.
     *
     * The class counts on either side of the split are maintained
     * incrementally (one element crosses the split per step), so the scan is
     * O(n * classes) rather than the O(n^2) of a full recount per candidate.
     *
     * @param orderline         instances sorted by distance/value
     * @param classDistribution overall class value -> count
     * @return the best threshold, or -1 if no improving split exists
     */
    public static double infoGainThreshold(ArrayList<OrderLineObj> orderline, TreeMap<Double,Integer> classDistribution){
        int n=orderline.size();
        // Parent entropy is loop-invariant: compute it once.
        double parentEntropy=entropy(classDistribution);
        // Start with everything above the split; elements migrate below one at a time.
        TreeMap<Double,Integer> lessClasses=new TreeMap<Double,Integer>();
        TreeMap<Double,Integer> greaterClasses=new TreeMap<Double,Integer>();
        for(Double d:classDistribution.keySet()){
            lessClasses.put(d,0);
            greaterClasses.put(d,0);
        }
        for(int i=0;i<n;i++){
            double classVal=orderline.get(i).getClassVal();
            greaterClasses.put(classVal,greaterClasses.get(classVal)+1);
        }
        double lastDist=orderline.get(0).getDistance(); // no point splitting before any data
        double bsfGain=-1;
        double threshold=-1;
        for(int i=1;i<n;i++){
            // Element i-1 moves from the "greater" side to the "less" side.
            double movedClass=orderline.get(i-1).getClassVal();
            lessClasses.put(movedClass,lessClasses.get(movedClass)+1);
            greaterClasses.put(movedClass,greaterClasses.get(movedClass)-1);
            double thisDist=orderline.get(i).getDistance();
            // Only evaluate when the threshold has actually moved; special-case
            // i==1 so the first candidate is always tested.
            if(i==1 || thisDist!=lastDist){
                double lessFrac=(double)i/n;
                double greaterFrac=(double)(n-i)/n;
                double gain=parentEntropy
                        -lessFrac*entropy(lessClasses)
                        -greaterFrac*entropy(greaterClasses);
                if(gain>bsfGain){
                    bsfGain=gain;
                    threshold=(thisDist-lastDist)/2+lastDist;
                }
            }
            lastDist=thisDist;
        }
        return threshold;
    }

    /** @return a short description of this filter (Weka GUI convention). */
    @Override
    public String globalInfo(){
        return "A binary filter that thresholds each numeric attribute at the "
                +"split point maximising information gain on the class.";
    }
}