package tr.gov.ulakbim.jDenetX.core;
import weka.clusterers.SimpleKMeans;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
import java.util.ArrayList;
import java.util.Collections;
/**
 * Splits a set of instances into per-cluster training sets for an ensemble.
 *
 * <p>The instances are clustered with k-means after their last attribute (treated as the
 * class attribute) has been removed. The clusters are then ranked by the distance of their
 * centroid to the centroid of the whole set, closest first, and every instance is placed
 * into the result set that corresponds to the rank of its cluster.
 *
 * <p>The normalized centroid distances are additionally accumulated in a <b>static</b>
 * weight vector. That vector is shared by every instance of this class and is accessed
 * without any synchronization in this class.
 *
 * <p>Created by IntelliJ IDEA.
 * User: caglar
 * Date: Sep 3, 2010
 * Time: 11:27:12 AM
 */
public class ClusterTrainingDataHarvester {

    /** Iteration limit passed to every k-means run performed by this class. */
    private static final int MAX_KMEANS_ITERATIONS = 500;

    /** Accumulated cluster weights; static, hence shared by all harvester instances. */
    private static DoubleVector sharedWeights = new DoubleVector();

    /** Number of clusters requested from k-means and length of the returned array. */
    private final int noOfClusters;

    /** Centroids produced by the most recent call to {@link #getClusterAssignments}. */
    private Instances centroids;

    /**
     * Creates a harvester that partitions data into the given number of clusters.
     *
     * @param noOfClusters number of clusters to request from k-means
     */
    public ClusterTrainingDataHarvester(int noOfClusters) {
        this.noOfClusters = noOfClusters;
    }

    /**
     * Partitions {@code classInstances} into one data set per cluster, with the cluster whose
     * centroid is closest to the centroid of the whole set placed first.
     *
     * <p>As a side effect the normalized centroid distances of this call are merged into the
     * shared weights: if weights already exist the new ones are added and the sum is scaled
     * by 0.5, otherwise the new weights are stored as they are. Note that the weights are
     * indexed by k-means cluster number, whereas the returned array is indexed by rank.
     *
     * <p>Any exception raised while clustering is printed to standard error and the array is
     * returned in whatever state it had reached (entries may be {@code null} or empty).
     *
     * @param classInstances instances to partition; the code treats the last attribute as the
     *                       class attribute when preparing the data for clustering
     * @param noOfClasses    not used by this implementation; kept for interface compatibility
     * @return array of length {@code noOfClusters}, ordered from closest to farthest cluster
     */
    // The raw ArrayList is kept from the original code so that no additional Weka type has
    // to be imported; the suppression is limited to this method.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Instances[] getEnsembleTrainingData(Instances classInstances, int noOfClasses) {
        DoubleVector clusterWeights = new DoubleVector();
        Instances[] resultData = new Instances[noOfClusters];
        // Header of the result sets: all non-class attributes followed by the class attribute.
        ArrayList attList = Collections.list(classInstances.enumerateAttributes());
        attList.add(classInstances.classAttribute());
        try {
            for (int i = 0; i < resultData.length; i++) {
                resultData[i] = new Instances("resultData", attList, classInstances.size());
                resultData[i].setClassIndex(attList.size() - 1);
            }

            Instance classCentroid = getClassCentroid(classInstances);
            EuclideanDistance distanceFun = new EuclideanDistance();
            // Also refreshes the centroids field used below.
            int[] assignments = getClusterAssignments(classInstances);
            int clusterCount = centroids.size();
            Instances classCentroids = (Instances) MiscUtils.deepCopy(centroids);
            distanceFun.setInstances(classCentroids);

            for (int i = 0; i < clusterCount; i++) {
                double distance = distanceFun.distance(classCentroid, classCentroids.get(i));
                clusterWeights.setValue(i, distance);
            }
            clusterWeights.normalize();

            // Cluster indices that have not been ranked yet.
            int[] candidates = new int[clusterCount];
            for (int i = 0; i < clusterCount; i++) {
                candidates[i] = i;
            }

            // rankOfCluster[c] is the position of cluster c in the closest-first ordering.
            // closestPoint reports a cluster index, i.e. a VALUE stored in candidates, so the
            // chosen cluster has to be removed by value. Treating that value as a position in
            // the shrinking candidate array drops the wrong cluster or runs out of bounds.
            int[] rankOfCluster = new int[clusterCount];
            for (int rank = 0; rank < clusterCount; rank++) {
                int closest = distanceFun.closestPoint(classCentroid, classCentroids, candidates);
                rankOfCluster[closest] = rank;
                candidates = removeValue(candidates, closest);
            }

            for (int i = 0; i < classInstances.size(); i++) {
                resultData[rankOfCluster[assignments[i]]].add(classInstances.get(i));
            }

            if (sharedWeights.numValues() > 0) {
                // Blend with the previously stored weights: add, then halve.
                sharedWeights.addValues(clusterWeights);
                sharedWeights.scaleValues(0.5);
            } else {
                sharedWeights = (DoubleVector) clusterWeights.copy();
            }
        } catch (Exception e) {
            // Best effort: report the failure and hand back whatever has been built so far.
            e.printStackTrace();
        }
        return resultData;
    }

    /**
     * Returns a copy of {@code values} without the first element equal to {@code value}.
     * The relative order of the remaining elements is preserved; if the value is not present
     * the input array itself is returned.
     */
    private static int[] removeValue(int[] values, int value) {
        for (int pos = 0; pos < values.length; pos++) {
            if (values[pos] == value) {
                int[] remaining = new int[values.length - 1];
                System.arraycopy(values, 0, remaining, 0, pos);
                System.arraycopy(values, pos + 1, remaining, pos, values.length - pos - 1);
                return remaining;
            }
        }
        return values;
    }

    /**
     * Deep-copies the data, marks the last attribute as the class attribute and removes it
     * with a {@link Remove} filter so the result can be handed to a clusterer.
     */
    private static Instances withoutClassAttribute(Instances classInstances) throws Exception {
        Instances insts = (Instances) MiscUtils.deepCopy(classInstances);
        insts.setClassIndex(insts.numAttributes() - 1);
        Remove filter = new Remove();
        // Remove takes 1-based attribute indices, classIndex() is 0-based.
        filter.setAttributeIndices("" + (insts.classIndex() + 1));
        filter.setInputFormat(insts);
        return Filter.useFilter(insts, filter);
    }

    /**
     * Computes the centroid of the whole data set by running k-means with a single cluster
     * on the data without its class attribute.
     *
     * @param classInstances instances whose centroid is wanted
     * @return the single centroid found by k-means
     * @throws Exception if filtering or clustering fails
     */
    private Instance getClassCentroid(Instances classInstances) throws Exception {
        Instances filteredData = withoutClassAttribute(classInstances);
        SimpleKMeans kmeans = new SimpleKMeans();
        kmeans.setNumClusters(1);
        kmeans.setMaxIterations(MAX_KMEANS_ITERATIONS);
        kmeans.buildClusterer(filteredData);
        return kmeans.getClusterCentroids().firstInstance();
    }

    /**
     * Clusters the data (without its class attribute) into {@code noOfClusters} clusters and
     * stores the resulting centroids in the {@code centroids} field.
     *
     * @param classInstances instances to cluster
     * @return the cluster assignments reported by k-means, with instance order preserved
     * @throws Exception if filtering or clustering fails
     */
    private int[] getClusterAssignments(Instances classInstances) throws Exception {
        Instances filteredData = withoutClassAttribute(classInstances);
        SimpleKMeans kmeans = new SimpleKMeans();
        kmeans.setNumClusters(noOfClusters);
        kmeans.setMaxIterations(MAX_KMEANS_ITERATIONS);
        // Required so that getAssignments() can be queried after building.
        kmeans.setPreserveInstancesOrder(true);
        kmeans.buildClusterer(filteredData);
        int[] assignments = kmeans.getAssignments();
        centroids = kmeans.getClusterCentroids();
        return assignments;
    }

    /**
     * Returns the accumulated cluster weights shared by all harvester instances.
     *
     * @return the shared weight vector itself (not a copy)
     */
    public DoubleVector getWeights() {
        return sharedWeights;
    }

    /** Replaces the shared weights with a new, empty vector. */
    public void clearWeights() {
        sharedWeights = new DoubleVector();
    }
}