package edu.harvard.mcb.leschziner.classify;

import java.util.Map;

import com.hazelcast.core.MultiMap;

import edu.harvard.mcb.leschziner.analyze.CrossCorrelator;
import edu.harvard.mcb.leschziner.core.Particle;
import edu.harvard.mcb.leschziner.distributed.DistributedProcessingTask;
import edu.harvard.mcb.leschziner.storage.DefaultStorageEngine;
import edu.harvard.mcb.leschziner.storage.StorageEngine;

/**
 * A task that classifies a single particle against many references using
 * pearson cross correlation.
 * <p>
 * The particle is scored against every template; if the best score reaches
 * the match threshold, the particle is added to that template's class and
 * the cached average for that class is dropped.
 * 
 * @author spartango
 * 
 */
public class CrossCorClassifierTask extends DistributedProcessingTask {

    private static final long serialVersionUID = -5350862097468663627L;

    // Particle to be processed
    private final Particle    target;

    // Minimum correlation necessary to include this particle in a class
    private final double      matchThreshold;

    // Name of the map of classes (distributed multimap: template id -> members)
    private final String      classMapName;

    // Name of the cache of class averages (distributed map)
    private final String      averagesMapName;

    // Name of the collection of templates. Despite the "set" in the name, it
    // is looked up as a distributed map keyed by template id.
    private final String      templateSetName;

    /**
     * Builds a classification task to be executed in the future
     * 
     * @param target
     *            : particle to be classified
     * @param matchThreshold
     *            : minimum correlation necessary to allow classification
     * @param classMapName
     *            : name of map of classes (distributed)
     * @param averagesMapName
     *            : name of map of averages (distributed)
     * @param templateSetName
     *            : name of the map of templates, keyed by template id
     *            (distributed)
     * @param executorName
     *            : name of executor which will run this task (distributed)
     */
    public CrossCorClassifierTask(Particle target,
                                  double matchThreshold,
                                  String classMapName,
                                  String averagesMapName,
                                  String templateSetName,
                                  String executorName) {
        super(executorName);
        this.target = target;
        this.matchThreshold = matchThreshold;
        this.classMapName = classMapName;
        this.averagesMapName = averagesMapName;
        this.templateSetName = templateSetName;
    }

    /**
     * Do the classification: score the target against every template, and if
     * the best score is at least {@code matchThreshold}, add the target to
     * that template's class and remove the class's cached average.
     * <p>
     * Does nothing when there are no templates or no score reaches the
     * threshold.
     */
    @Override
    public void process() {
        // Pull up distributed maps
        StorageEngine storage = DefaultStorageEngine.getStorageEngine();
        MultiMap<Long, Particle> classes = storage.getMultiMap(classMapName);
        Map<Long, Particle> classAverages = storage.getMap(averagesMapName);
        Map<Long, Particle> templates = storage.getMap(templateSetName);

        // Start below any possible score so that matchThreshold alone decides
        // admission; a starting floor of 0 would silently reject scores in
        // [matchThreshold, 0] whenever the threshold is not positive.
        double bestCorrelation = Double.NEGATIVE_INFINITY;
        Long bestTemplateId = null;

        // Iterate through the templates, scoring pearson correlation.
        // Walking the entry set yields key and template together, avoiding a
        // second lookup per template and the null a separate get() would
        // return if the template were removed in between.
        for (Map.Entry<Long, Particle> entry : templates.entrySet()) {
            double score = CrossCorrelator.compare(target, entry.getValue());

            // Select best correlation (first template wins ties)
            if (score > bestCorrelation) {
                bestCorrelation = score;
                bestTemplateId = entry.getKey();
            }
        }

        // Add target particle to closest match, if it's above the threshold
        if (bestTemplateId != null && bestCorrelation >= matchThreshold) {
            // Add to class
            classes.put(bestTemplateId, target);
            // Invalidate the class average cache, since membership changed
            classAverages.remove(bestTemplateId);
        }
    }

}