package pl.piotrsukiennik.whowhen.classification.impl.weka;
import pl.piotrsukiennik.whowhen.classification.impl.ILabelingClusterer;
import pl.piotrsukiennik.whowhen.classification.impl.LabelingClassificationResult;
import weka.clusterers.Clusterer;
import weka.clusterers.NumberOfClustersRequestable;
import weka.core.Instance;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Adapts a Weka {@link Clusterer} to the {@link ILabelingClusterer} contract.
 * <p>
 * Each cluster is exposed as a {@link LabelingClassificationResult} labelled {@code "cluster" + index},
 * where {@code index} is the cluster's position in the membership distribution returned by
 * {@link Clusterer#distributionForInstance(Instance)}.
 *
 * @author Piotr Sukiennik
 */
public class WekaLabelingClusterer implements ILabelingClusterer, Cloneable {

    private static final Logger LOGGER = Logger.getLogger( WekaLabelingClusterer.class.getName() );

    /** The wrapped Weka clusterer that all clustering work is delegated to. */
    protected Clusterer clusterer;

    /** Label forwarded to {@code WekaUtil.toInstances} together with the training data. */
    private String requiredLabel;

    /**
     * @param clusterer the Weka clusterer to delegate to
     */
    public WekaLabelingClusterer( Clusterer clusterer ) {
        this.clusterer = clusterer;
    }

    /**
     * Requests a specific number of clusters from the wrapped clusterer.
     * <p>
     * The request is silently ignored when {@code numberClusters <= 0} or when the wrapped
     * clusterer does not implement {@link NumberOfClustersRequestable}.
     *
     * @param numberClusters desired number of clusters
     * @throws RuntimeException wrapping any exception thrown by {@code setNumClusters}
     */
    @Override
    public void setExpectedNumberClusters( int numberClusters ) {
        if ( numberClusters <= 0 || !( clusterer instanceof NumberOfClustersRequestable ) ) {
            return;
        }
        try {
            ( (NumberOfClustersRequestable) clusterer ).setNumClusters( numberClusters );
        }
        catch ( Exception e ) {
            throw new RuntimeException( e );
        }
    }

    /**
     * Equivalent to {@code getClassification(vector)} with no labels.
     *
     * @param vector feature vector to classify
     * @return the first entry of the sorted cluster distribution
     * @throws IllegalStateException if no (non-empty) distribution could be computed
     */
    @Override
    public LabelingClassificationResult getClassification( double[] vector ) {
        return this.getClassification( vector, new String[] { } );
    }

    /**
     * Returns the first entry of {@link #getClassificationDistribution(double[])}.
     * <p>
     * The {@code labels} argument is currently ignored.
     *
     * @param vector feature vector to classify
     * @param labels ignored
     * @return the first entry of the sorted cluster distribution
     * @throws IllegalStateException if the distribution is {@code null} (clusterer failure) or empty
     */
    @Override
    public LabelingClassificationResult getClassification( double[] vector, String... labels ) {
        LabelingClassificationResult[] ranked = getClassificationDistribution( vector );
        // Guard explicitly: the distribution is null when the clusterer fails, and indexing it
        // blindly would surface as an unexplained NullPointerException / index error.
        if ( ranked == null || ranked.length == 0 ) {
            throw new IllegalStateException( "No cluster distribution available for the given vector" );
        }
        return ranked[0];
    }

    /**
     * Computes the cluster-membership distribution for {@code vector}.
     * <p>
     * The results are sorted with {@link LabelingClassificationResult#PROBABILITY_COMPARATOR}.
     * {@link #getClassification(double[], String...)} takes element 0 as the best match, which
     * presumes the comparator puts the most probable cluster first (TODO: confirm).
     *
     * @param vector feature vector to classify
     * @return one result per cluster, labelled {@code "cluster" + index}; or {@code null} if the
     *         wrapped clusterer threw (the failure is logged)
     */
    @Override
    public LabelingClassificationResult[] getClassificationDistribution( double[] vector ) {
        Instance instance = WekaUtil.toInstance( vector );
        try {
            double[] distribution = clusterer.distributionForInstance( instance );
            LabelingClassificationResult[] results = new LabelingClassificationResult[distribution.length];
            for ( int i = 0; i < distribution.length; i++ ) {
                // Label is assigned before sorting, so it records the cluster's original index.
                results[i] = new LabelingClassificationResult( "cluster" + i, distribution[i] );
            }
            Arrays.sort( results, LabelingClassificationResult.PROBABILITY_COMPARATOR );
            return results;
        }
        catch ( Exception e ) {
            LOGGER.log( Level.WARNING, "Clusterer failed to compute a distribution", e );
            return null;
        }
    }

    /**
     * Builds the wrapped clusterer from {@code data}, using {@link #getRequiredLabel()} when
     * converting the data to Weka instances.
     * <p>
     * Failures are logged and not rethrown. In that case the clusterer is left in whatever state
     * {@code buildClusterer} left it.
     *
     * @param data training vectors
     */
    @Override
    public void train( List<double[]> data ) {
        try {
            clusterer.buildClusterer( WekaUtil.toInstances( data, getRequiredLabel() ) );
        }
        catch ( Exception e ) {
            LOGGER.log( Level.SEVERE, "Failed to build clusterer", e );
        }
    }

    /**
     * @param requiredLabel label forwarded to {@code WekaUtil.toInstances} during training
     */
    public void setRequiredLabel( String requiredLabel ) {
        this.requiredLabel = requiredLabel;
    }

    /**
     * @return the label forwarded to {@code WekaUtil.toInstances} during training; may be {@code null}
     */
    @Override
    public String getRequiredLabel() {
        return requiredLabel;
    }

    /**
     * Creates a new wrapper around a fresh instance of the same clusterer class.
     * <p>
     * The new clusterer is created through its no-argument constructor, and {@code buildClusterer}
     * has not been called on it. Only {@code requiredLabel} is copied. Configuration applied to
     * the original clusterer, such as via {@link #setExpectedNumberClusters(int)}, is not carried
     * over. The result is always a {@code WekaLabelingClusterer}, even when called on a subclass.
     *
     * @return a new, independent wrapper
     * @throws RuntimeException if the clusterer class cannot be instantiated reflectively
     */
    @Override
    public WekaLabelingClusterer clone() {
        try {
            // Class#newInstance() is deprecated (it can leak undeclared checked exceptions);
            // the Constructor path wraps them in InvocationTargetException instead.
            Clusterer freshClusterer = clusterer.getClass().getDeclaredConstructor().newInstance();
            WekaLabelingClusterer copy = new WekaLabelingClusterer( freshClusterer );
            copy.setRequiredLabel( this.getRequiredLabel() );
            return copy;
        }
        catch ( Exception e ) {
            throw new RuntimeException( e );
        }
    }
}