package pl.piotrsukiennik.whowhen.classification.impl;

import pl.piotrsukiennik.whowhen.classification.label.LabelledIndexIntervals;
import pl.piotrsukiennik.whowhen.classification.smooth.Smoother;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Runs a sequence of labeling classifiers over a list of feature vectors and,
 * optionally, a clusterer over the vectors that received the clusterer's
 * required label. Each classifier (and the clusterer, when used) contributes
 * one "level" of labelled index intervals to the result.
 *
 * @author Piotr Sukiennik
 */
public class ClassifierChain {

    /** Range handed to the {@link Smoother} that post-processes cluster labels. */
    private static final int SMOOTHING_RANGE = 5;

    private File classifierDataDirectory;

    private final ILabelingClassifier[] classifiers;

    private final ILabelingClusterer clutererTemplate;

    /**
     * @param classifiers             classifiers applied, in array order, to every feature vector
     * @param clutererTemplate        clusterer prototype; it is cloned for every {@code process} call,
     *                                may be {@code null} to disable clustering
     * @param classifierDataDirectory directory exposed through
     *                                {@link #getClassifierDataDirectory()}; not read by this class
     */
    public ClassifierChain( ILabelingClassifier[] classifiers,
                            ILabelingClusterer clutererTemplate,
                            File classifierDataDirectory ) {
        this.classifierDataDirectory = classifierDataDirectory;
        this.classifiers = classifiers;
        this.clutererTemplate = clutererTemplate;
    }

    /**
     * Processes the features without overriding the clusterer's expected cluster count.
     *
     * @param features feature vectors to label
     * @return see {@link #process(List, Integer)}
     */
    public List<Map<String, List<int[]>>> process( List<double[]> features ) {
        return process( features, null );
    }

    /**
     * Classifies every feature vector with each classifier and then, if a clusterer
     * template was supplied and {@code expectedClusterCount} is not zero, clusters the
     * vectors carrying the clusterer's required label.
     *
     * @param features             feature vectors to label
     * @param expectedClusterCount number of clusters to request from the clusterer;
     *                             {@code null} leaves the clusterer's own setting untouched,
     *                             {@code 0} skips clustering entirely
     * @return one map of label to index ranges per classifier (in classifier order),
     *         followed by one map for the clusters when clustering was performed
     */
    public List<Map<String, List<int[]>>> process( List<double[]> features, Integer expectedClusterCount ) {
        Map<String, List<double[]>> classifiedFeatures = new HashMap<String, List<double[]>>();
        String[][] labelsAppliedToFeatureVectors = new String[features.size()][classifiers.length];

        LabelledIndexIntervals[] labelledIndexIntervals =
                applyClassifing( features, classifiedFeatures, labelsAppliedToFeatureVectors );

        List<Map<String, List<int[]>>> leveledLabelledIntervals = new ArrayList<Map<String, List<int[]>>>();
        for ( LabelledIndexIntervals labelledIntervals : labelledIndexIntervals ) {
            leveledLabelledIntervals.add( labelledIntervals.getLabelledIndexRanges() );
        }

        // The elapsed time is reported even when the clustering step is skipped.
        long from = System.currentTimeMillis();
        boolean clusteringRequested = expectedClusterCount == null || expectedClusterCount != 0;
        if ( clutererTemplate != null && clusteringRequested ) {
            // Work on a clone so the template is never trained or reconfigured.
            ILabelingClusterer clusterer = clutererTemplate.clone();
            if ( expectedClusterCount != null ) {
                clusterer.setExpectedNumberClusters( expectedClusterCount );
            }
            LabelledIndexIntervals clustersLabelledIndexIntervals =
                    applyClustering( clusterer, features, classifiedFeatures, labelsAppliedToFeatureVectors );
            leveledLabelledIntervals.add( clustersLabelledIndexIntervals.getLabelledIndexRanges() );
        }
        long to = System.currentTimeMillis();
        System.out.println( "CLUSTERING TOOK: " + ( to - from ) );

        return leveledLabelledIntervals;
    }

    /**
     * Trains the clusterer on the features carrying its required label, then assigns a
     * cluster label to every feature vector that carries that label.
     * <p>
     * Exactly one entry is submitted to the smoother per feature vector (the cluster
     * label, or {@code null} when the vector does not carry the required label), so that
     * position {@code i} of the smoothed output refers to feature vector {@code i}.
     * NOTE(review): this relies on the smoother returning one entry per submission, in
     * submission order - confirm against {@link Smoother}.
     *
     * @param clusterer                     clusterer to train and query
     * @param features                      all feature vectors, in index order
     * @param classifiedFeatures            features grouped by label; an empty list is registered
     *                                      for every cluster label not yet present
     * @param labelsAppliedToFeatureVectors per feature vector, the labels assigned by the classifiers
     * @return intervals of feature indices per (smoothed) cluster label
     */
    public LabelledIndexIntervals applyClustering( ILabelingClusterer clusterer,
                                                   List<double[]> features,
                                                   Map<String, List<double[]>> classifiedFeatures,
                                                   String[][] labelsAppliedToFeatureVectors ) {
        trainClusters( clusterer, classifiedFeatures );

        Smoother smoother = new Smoother();
        smoother.setRange( SMOOTHING_RANGE );

        String requiredLabel = clusterer.getRequiredLabel();
        for ( int i = 0; i < features.size(); i++ ) {
            if ( containsLabel( labelsAppliedToFeatureVectors[i], requiredLabel ) ) {
                LabelingClassificationResult classificationResult = clusterer.getClassification( features.get( i ) );
                String clusterLabel = classificationResult.getLabel();
                // Only registers the cluster label; the feature itself is not added to the
                // list (unchanged from the previous behaviour).
                featuresFor( classifiedFeatures, clusterLabel );
                smoother.submit( clusterLabel );
            } else {
                smoother.submit( null );
            }
        }

        LabelledIndexIntervals clustersLabelledIndexIntervals = new LabelledIndexIntervals();
        List<String> correctedLabels = smoother.getWithCorrectionApplied();
        for ( int i = 0; i < correctedLabels.size(); i++ ) {
            String label = correctedLabels.get( i );
            if ( label != null ) {
                clustersLabelledIndexIntervals.applyLabel( label, i );
            }
        }
        return clustersLabelledIndexIntervals;
    }

    /**
     * Applies every classifier to every feature vector.
     *
     * @param features                      feature vectors to classify
     * @param classifiedFeatures            output: every feature vector is appended to the list of
     *                                      each label it received
     * @param labelsAppliedToFeatureVectors output: row {@code f} is replaced with an array holding,
     *                                      at position {@code c}, the label classifier {@code c}
     *                                      gave to feature {@code f} ({@code null} if none)
     * @return one {@link LabelledIndexIntervals} per classifier, in classifier order
     */
    public LabelledIndexIntervals[] applyClassifing( List<double[]> features,
                                                     Map<String, List<double[]>> classifiedFeatures,
                                                     String[][] labelsAppliedToFeatureVectors ) {
        LabelledIndexIntervals[] labelledIndexIntervals = new LabelledIndexIntervals[classifiers.length];
        for ( int i = 0; i < labelledIndexIntervals.length; i++ ) {
            labelledIndexIntervals[i] = new LabelledIndexIntervals();
        }

        for ( int featuresIndex = 0; featuresIndex < features.size(); featuresIndex++ ) {
            double[] featureVector = features.get( featuresIndex );
            // One slot per classifier. A single-slot array here would overflow as soon as
            // a classifier at index >= 1 produced a label.
            String[] appliedLabels = new String[classifiers.length];
            labelsAppliedToFeatureVectors[featuresIndex] = appliedLabels;

            for ( int i = 0; i < classifiers.length; i++ ) {
                // Each classifier is handed the labels assigned so far to this vector.
                LabelingClassificationResult labelingClassificationResult =
                        classifiers[i].getClassification( featureVector, appliedLabels );
                if ( labelingClassificationResult == null ) {
                    continue;
                }
                String label = labelingClassificationResult.getLabel();
                labelledIndexIntervals[i].applyLabel( label, featuresIndex );
                featuresFor( classifiedFeatures, label ).add( featureVector );
                appliedLabels[i] = label;
            }
        }
        return labelledIndexIntervals;
    }

    /**
     * Trains the clusterer on the features stored under its required label; does nothing
     * when no features are stored under that label.
     *
     * @param clusterer          clusterer to train
     * @param classifiedFeatures features grouped by label
     */
    public void trainClusters( ILabelingClusterer clusterer, Map<String, List<double[]>> classifiedFeatures ) {
        List<double[]> trainingFeatures = classifiedFeatures.get( clusterer.getRequiredLabel() );
        if ( trainingFeatures != null ) {
            clusterer.train( trainingFeatures );
        }
    }

    public File getClassifierDataDirectory() {
        return classifierDataDirectory;
    }

    public void setClassifierDataDirectory( File classifierDataDirectory ) {
        this.classifierDataDirectory = classifierDataDirectory;
    }

    /** Returns whether {@code requiredLabel} equals any entry of {@code labels}. */
    private static boolean containsLabel( String[] labels, String requiredLabel ) {
        for ( String label : labels ) {
            if ( requiredLabel.equals( label ) ) {
                return true;
            }
        }
        return false;
    }

    /** Returns the list stored under {@code label}, registering a new empty list if absent. */
    private static List<double[]> featuresFor( Map<String, List<double[]>> classifiedFeatures, String label ) {
        List<double[]> labelled = classifiedFeatures.get( label );
        if ( labelled == null ) {
            labelled = new ArrayList<double[]>();
            classifiedFeatures.put( label, labelled );
        }
        return labelled;
    }
}