package weka.classifiers.lazy;

import java.util.ArrayList;
import java.util.Random;
import java.util.TreeSet;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.*;

/**
 * Randomized sphere cover classifier.
 *
 * <p>Training repeatedly picks a random, not-yet-covered training case as a sphere centre,
 * sets the sphere radius to the distance of the closest training case with a different class
 * value, and keeps the sphere only if the list of cases it covers has at least {@code alpha}
 * entries. Classification returns the class of the closest sphere centre whose sphere contains
 * the query; if no sphere contains it, the class of the sphere whose surface is closest.
 *
 * @author Aaron
 */
public class RandomizedSphereCover extends AbstractClassifier implements Randomizable {

    /** Minimum size of a sphere's covered list for the sphere to be kept. */
    private int alpha;
    /** Distance used for both training and classification; Euclidean by default. */
    private NormalizableDistance distanceFunc;
    /** Training cases not yet covered by any kept sphere (deduplicated by InstanceComparator). */
    private TreeSet<Instance> uncoveredCases;
    /** The training data passed to the last buildClassifier call. */
    private Instances allCases;
    /** Spheres kept by the last buildClassifier call. */
    private ArrayList<Sphere> sphereSet;
    private int randSeed = 100;
    private Random random = new Random(randSeed);
    private boolean crossValidateAlpha = false;

    /**
     * Creates a classifier whose alpha is chosen by 10-fold cross-validation in
     * buildClassifier, trying alpha = 1 .. (numInstances / 10) - 1. With fewer than 20
     * training cases no candidate is tried and alpha stays 0.
     */
    public RandomizedSphereCover() {
        crossValidate(true);
        distanceFunc = new EuclideanDistance();
    }

    /**
     * Creates a classifier with a fixed alpha (no cross-validation).
     *
     * @param a minimum covered-list size for a sphere to be kept
     */
    public RandomizedSphereCover(int a) {
        this.alpha = a;
        distanceFunc = new EuclideanDistance();
    }

    /** Enables or disables choosing alpha by cross-validation in buildClassifier. */
    public final void crossValidate(boolean b) {
        crossValidateAlpha = b;
    }

    @Override
    public void setSeed(int seed) {
        random.setSeed(seed);
        randSeed = seed;
    }

    @Override
    public int getSeed() {
        return randSeed;
    }

    /**
     * Builds the sphere cover for the given training data.
     *
     * @param inst training data; its class attribute must be set
     * @throws IllegalStateException if cross-validating alpha fails (cause attached)
     */
    @Override
    public void buildClassifier(Instances inst) {
        if (crossValidateAlpha) {
            selectAlphaByCrossValidation(inst);
        }
        sphereSet = new ArrayList<Sphere>();
        uncoveredCases = new TreeSet<Instance>(new InstanceComparator());
        distanceFunc.setInstances(inst);
        allCases = inst;
        for (int j = 0; j < inst.numInstances(); j++) {
            uncoveredCases.add(inst.instance(j));
        }

        while (!uncoveredCases.isEmpty()) {
            // Randomly pick an uncovered case as the next centre; it is never reconsidered.
            int rand = (int) (random.nextDouble() * uncoveredCases.size());
            Instance[] uncovered = uncoveredCases.toArray(new Instance[uncoveredCases.size()]);
            Instance centre = uncovered[rand];
            uncoveredCases.remove(centre);

            // Radius = distance to the closest case of a different class (last one wins ties).
            // If every case shares the centre's class, edge stays null and the radius stays
            // Double.MAX_VALUE, so the sphere covers everything.
            Instance edge = null;
            double radius = Double.MAX_VALUE;
            for (int j = 0; j < allCases.numInstances(); j++) {
                Instance candidate = allCases.instance(j);
                double d = distanceFunc.distance(centre, candidate);
                if (d <= radius && centre.classValue() != candidate.classValue()) {
                    radius = d;
                    edge = candidate;
                }
            }
            Sphere sphere = new Sphere(centre, radius);

            ArrayList<Instance> covered = new ArrayList<Instance>();
            // NOTE(review): the edge is added here and, unless the radius is 0, again by the
            // loop below (its distance equals the radius), so it is counted twice towards
            // alpha. Kept as-is to preserve existing model behaviour; confirm intent.
            if (edge != null) {
                // Adding null would make TreeSet.remove call the comparator with null (NPE).
                covered.add(edge);
            }
            // Every case within the radius except those at distance 0 from the centre.
            for (int j = 0; j < allCases.numInstances(); j++) {
                Instance candidate = allCases.instance(j);
                double d = distanceFunc.distance(centre, candidate);
                if (d <= radius && d != 0) {
                    covered.add(candidate);
                }
            }

            if (covered.size() >= alpha) {
                // Remove one element at a time: TreeSet.remove uses the comparator, whereas
                // removeAll may fall back to the argument's equals()-based contains().
                for (Instance c : covered) {
                    uncoveredCases.remove(c);
                }
                sphereSet.add(sphere);
            }
        }
    }

    /**
     * Sets alpha to the candidate with the highest 10-fold cross-validated accuracy.
     *
     * <p>This is a deliberately simple first attempt: spheres are rebuilt for every fold of
     * every candidate. Candidates use this classifier's distance function and draw fold
     * shuffles from this classifier's random generator.
     */
    private void selectAlphaByCrossValidation(Instances inst) {
        double bestAccuracy = 0;
        int maxAlpha = inst.numInstances() / 10;
        int folds = 10;
        for (int a = 1; a < maxAlpha; a++) {
            RandomizedSphereCover candidate = new RandomizedSphereCover(a);
            candidate.setDistanceFunc(distanceFunc);
            try {
                Evaluation e = new Evaluation(inst);
                e.crossValidateModel(candidate, inst, folds, random);
                double acc = e.correct() / inst.numInstances();
                if (acc > bestAccuracy) {
                    bestAccuracy = acc;
                    this.alpha = a;
                }
            } catch (Exception e) {
                throw new IllegalStateException("Cross-validation failed for alpha=" + a, e);
            }
        }
    }

    /**
     * Classifies an instance using the sphere cover.
     *
     * <p>If one or more spheres contain the instance, returns the class of the closest such
     * centre. Otherwise returns the class of the sphere with the smallest
     * (distance - radius); ties go to the later sphere.
     *
     * @param i instance to classify
     * @return predicted class value
     * @throws Exception if the classifier has not been built or produced no spheres
     */
    @Override
    public double classifyInstance(Instance i) throws Exception {
        if (sphereSet == null || sphereSet.isEmpty()) {
            throw new Exception("No Spheres in the set");
        }
        int closestCentre = -1;
        double closestCentreDistance = Double.MAX_VALUE;
        int closestSphere = 0;
        double closestSurfaceGap = Double.MAX_VALUE;
        for (int j = 0; j < sphereSet.size(); j++) {
            Sphere s = sphereSet.get(j);
            double d = distanceFunc.distance(s.getCentre(), i);
            if (d <= s.getRadius()) {
                // Inside this sphere: keep the closest containing centre (first wins ties).
                if (closestCentre == -1 || d < closestCentreDistance) {
                    closestCentre = j;
                    closestCentreDistance = d;
                }
            } else if (d - s.getRadius() <= closestSurfaceGap) {
                // Outside: track the sphere whose surface is nearest.
                closestSurfaceGap = d - s.getRadius();
                closestSphere = j;
            }
        }
        int chosen = (closestCentre != -1) ? closestCentre : closestSphere;
        return sphereSet.get(chosen).getCentre().classValue();
    }

    /** Replaces the distance function used for training and classification. */
    public void setDistanceFunc(NormalizableDistance in) {
        distanceFunc = in;
    }

    /** Returns the spheres built by the last buildClassifier call (live list, not a copy). */
    public ArrayList<Sphere> getSphereSet() {
        return sphereSet;
    }

    @Override
    public String getRevision() {
        return "1";
    }

    /** A sphere: a training instance as centre plus a radius. */
    public static class Sphere {
        private Instance centre;
        private double radius;

        public Sphere(Instance c, double r) {
            this.centre = c;
            this.radius = r;
        }

        public Instance getCentre() {
            return centre;
        }

        public double getRadius() {
            return radius;
        }
    }

    public static void main(String[] args) {
        System.out.println(" Test harness not implemented");
    }
}