package cz.cuni.lf1.lge.ThunderSTORM;
import cz.cuni.lf1.lge.ThunderSTORM.UI.GUI;
import cz.cuni.lf1.lge.ThunderSTORM.UI.Help;
import cz.cuni.lf1.lge.ThunderSTORM.UI.MacroParser;
import cz.cuni.lf1.lge.ThunderSTORM.estimators.PSF.Molecule;
import cz.cuni.lf1.lge.ThunderSTORM.estimators.PSF.MoleculeDescriptor;
import cz.cuni.lf1.lge.ThunderSTORM.estimators.PSF.MoleculeDescriptor.Units;
import cz.cuni.lf1.lge.ThunderSTORM.results.GenericTable;
import cz.cuni.lf1.lge.ThunderSTORM.results.IJGroundTruthTable;
import cz.cuni.lf1.lge.ThunderSTORM.results.IJResultsTable;
import cz.cuni.lf1.lge.ThunderSTORM.util.GridBagHelper;
import cz.cuni.lf1.lge.ThunderSTORM.util.MoleculeMatcher;
import cz.cuni.lf1.lge.ThunderSTORM.util.Pair;
import cz.cuni.lf1.lge.ThunderSTORM.util.VectorMath;
import cz.cuni.lf1.lge.ThunderSTORM.util.MacroUI.DialogStub;
import cz.cuni.lf1.lge.ThunderSTORM.util.MacroUI.ParameterKey;
import cz.cuni.lf1.lge.ThunderSTORM.util.MacroUI.ParameterTracker;
import cz.cuni.lf1.lge.ThunderSTORM.util.MacroUI.validators.DoubleValidatorFactory;
import ij.IJ;
import ij.measure.ResultsTable;
import ij.plugin.PlugIn;
import javax.swing.*;
import java.awt.*;
import java.util.*;
import java.util.List;
import static cz.cuni.lf1.lge.ThunderSTORM.util.MathProxy.*;
public class PerformanceEvaluationPlugIn implements PlugIn {
private int processingFrame;
private int frames;
@Override
public void run(String command) {
GUI.setLookAndFeel();
Units distUnits = Units.NANOMETER;
//
if("showGroundTruthTable".equals(command)) {
IJGroundTruthTable.getGroundTruthTable().show();
return;
}
if(!IJResultsTable.isResultsWindow() || !IJGroundTruthTable.isGroundTruthWindow()) {
IJ.error("Requires `" + IJResultsTable.IDENTIFIER + "` and `" + IJGroundTruthTable.IDENTIFIER + "` windows open!");
return;
}
//
try {
// Create and show the dialog
PerformanceEvaluationDialog dialog = new PerformanceEvaluationDialog();
if(MacroParser.isRanFromMacro()) {
dialog.getParams().readMacroOptions();
} else {
if(dialog.showAndGetResult() != JOptionPane.OK_OPTION) {
return;
}
}
if(IJGroundTruthTable.getGroundTruthTable().isEmpty()) {
ResultsTable rt = ResultsTable.getResultsTable();
rt.incrementCounter();
rt.addValue("Distance radius [" + distUnits.getLabel() + "]", dialog.getToleranceRadius());
rt.addValue("# of TP", 0);
rt.addValue("# of FP", IJResultsTable.getResultsTable().getRowCount());
rt.addValue("# of FN", 0);
rt.addValue("Jaccard index", 0);
rt.addValue("precision", 0);
rt.addValue("recall", 0);
rt.addValue("F1-measure", 0);
rt.addValue("RMSE x [" + distUnits.getLabel() + "]", 0);
rt.addValue("RMSE y [" + distUnits.getLabel() + "]", 0);
rt.addValue("RMSE lateral [" + distUnits.getLabel() + "]", 0);
rt.addValue("RMSE axial [" + distUnits.getLabel() + "]", 0);
rt.addValue("RMSE total [" + distUnits.getLabel() + "]", 0);
rt.show("Results");
return;
}
if(IJResultsTable.getResultsTable().isEmpty()) {
ResultsTable rt = ResultsTable.getResultsTable();
rt.incrementCounter();
rt.addValue("Distance radius [" + distUnits.getLabel() + "]", dialog.getToleranceRadius());
rt.addValue("# of TP", 0);
rt.addValue("# of FP", 0);
rt.addValue("# of FN", IJGroundTruthTable.getGroundTruthTable().getRowCount());
rt.addValue("Jaccard index", 0);
rt.addValue("precision", 0);
rt.addValue("recall", 0);
rt.addValue("F1-measure", 0);
rt.addValue("RMSE x [" + distUnits.getLabel() + "]", 0);
rt.addValue("RMSE y [" + distUnits.getLabel() + "]", 0);
rt.addValue("RMSE lateral [" + distUnits.getLabel() + "]", 0);
rt.addValue("RMSE axial [" + distUnits.getLabel() + "]", 0);
rt.addValue("RMSE total [" + distUnits.getLabel() + "]", 0);
rt.show("Results");
return;
}
runEvaluation(dialog.getFrameByFrame(), dialog.getEvaluationSpace().equals("xyz"), sqr(dialog.getToleranceRadius()), distUnits);
} catch (Exception ex) {
IJ.handleException(ex);
}
}
private void prepareResultsTable(Units units) {
try {
// insert the new columns before parallel processing starts
IJResultsTable rt = IJResultsTable.getResultsTable();
MoleculeDescriptor descriptor = rt.getDescriptor();
int lastIndex = rt.getRow(0).values.length;
descriptor.addParam(MoleculeDescriptor.LABEL_GROUND_TRUTH_ID, lastIndex, Units.UNITLESS);
descriptor.addParam(MoleculeDescriptor.LABEL_DISTANCE_TO_GROUND_TRUTH_XY, lastIndex+1, units);
descriptor.addParam(MoleculeDescriptor.LABEL_DISTANCE_TO_GROUND_TRUTH_Z, lastIndex+2, units);
descriptor.addParam(MoleculeDescriptor.LABEL_DISTANCE_TO_GROUND_TRUTH_XYZ, lastIndex+3, units);
} catch(Exception ex) {
//
}
}
private void runEvaluation(boolean frameByFrame, boolean threeD, double dist2Tol, Units distUnits) {
//
int cores = frameByFrame ? Runtime.getRuntime().availableProcessors() : 1;
MoleculeMatcherWorker [] matchers = new MoleculeMatcherWorker[cores];
Thread [] threads = new Thread[cores];
processingFrame = 1;
prepareResultsTable(distUnits);
try {
frames = frameByFrame ? getFrameCount() : 1;
// prepare the workers and allocate resources for all the threads
for (int c = 0, f_start = 0, f_end, f_inc = frames / cores; c < cores; c++) {
if ((c + 1) < cores) {
f_end = f_start + f_inc;
} else {
f_end = frames;
}
if (frameByFrame) {
matchers[c] = new MoleculeMatcherWorker(f_start, f_end, threeD, dist2Tol, distUnits);
} else {
matchers[c] = new MoleculeMatcherWorker(-1, -1, threeD, dist2Tol, distUnits);
}
threads[c] = new Thread(matchers[c]);
f_start = f_end + 1;
}
// start all the workers
for (int c = 0; c < cores; c++) {
threads[c].start();
}
// wait for all the workers to finish
int wait = 1000 / cores; // max 1s
boolean finished = false;
while (!finished) {
finished = true;
for (int c = 0; c < cores; c++) {
threads[c].join(wait);
finished &= !threads[c].isAlive(); // all threads must not be alive to finish!
}
if (IJ.escapePressed()) { // abort?
// stop the workers
for (int ci = 0; ci < cores; ci++) {
threads[ci].interrupt();
}
// wait so the message below is not overwritten by any of the threads
for (int ci = 0; ci < cores; ci++) {
threads[ci].join();
}
// show info and exit the plugin
IJ.showProgress(1.0);
IJ.showStatus("Operation has been aborted by user!");
return;
}
}
} catch (IndexOutOfBoundsException ex) {
IJ.showMessage("Column `frame` does not exist! Either fill the column, or don't use frame-by-frame evaluation.");
return;
} catch (InterruptedException ex) {
//
} finally {
IJ.showProgress(1.0);
IJ.showStatus("");
}
//
IJ.showStatus("Gathering results...");
//
Vector<Pair<Molecule,Molecule>> TP = new Vector<Pair<Molecule,Molecule>>();
double tp = 0.0, fp = 0.0, fn = 0.0;
for(MoleculeMatcherWorker matcher : matchers) {
tp += (double) matcher.TP.size();
fp += (double) matcher.FP.size();
fn += (double) matcher.FN.size();
TP.addAll(matcher.TP);
}
double jaccard = tp / (tp + fp + fn);
double precision = tp / (tp + fp);
double recall = tp / (tp + fn);
double F1 = 2 * precision * recall / (precision + recall);
double calcRMSEx = calcRMSEx(TP, distUnits);
double calcRMSEy = calcRMSEy(TP, distUnits);
double RMSExy = calcRMSExy(TP, distUnits);
double RMSEz = calcRMSEz(TP, distUnits);
double RMSExyz = calcRMSExyz(TP, distUnits);
//
ResultsTable rt = ResultsTable.getResultsTable();
rt.incrementCounter();
rt.addValue("Distance radius [" + distUnits.getLabel()+ "]", sqrt(dist2Tol));
rt.addValue("# of TP", tp);
rt.addValue("# of FP", fp);
rt.addValue("# of FN", fn);
rt.addValue("Jaccard index", jaccard);
rt.addValue("precision", precision);
rt.addValue("recall", recall);
rt.addValue("F1-measure", F1);
rt.addValue("RMSE x [" + distUnits.getLabel() + "]", calcRMSEx);
rt.addValue("RMSE y [" + distUnits.getLabel() + "]", calcRMSEy);
rt.addValue("RMSE lateral [" + distUnits.getLabel() + "]", RMSExy);
rt.addValue("RMSE axial [" + distUnits.getLabel() + "]", RMSEz);
rt.addValue("RMSE total [" + distUnits.getLabel() + "]", RMSExyz);
rt.show("Results");
//
IJResultsTable.getResultsTable().fireStructureChanged();
IJResultsTable.getResultsTable().fireDataChanged();
IJGroundTruthTable.getGroundTruthTable().fireDataChanged();
//
IJ.showStatus("");
}
private double calcRMSExyz(List<Pair<Molecule, Molecule>> pairs, Units units) {
double err_sum = 0.0;
for(Pair<Molecule,Molecule> pair : pairs) {
err_sum += sqrt(pair.first.dist2xyz(pair.second, units));
}
return (err_sum / (double)pairs.size());
}
private double calcRMSExy(List<Pair<Molecule, Molecule>> pairs, Units units) {
double err_sum = 0;
for(Pair<Molecule,Molecule> pair : pairs) {
err_sum += sqrt(pair.first.dist2xy(pair.second, units));
}
return (err_sum / (double)pairs.size());
}
private double calcRMSEz(List<Pair<Molecule, Molecule>> pairs, Units units) {
double err_sum = 0;
for(Pair<Molecule,Molecule> pair : pairs) {
err_sum += sqrt(pair.first.dist2z(pair.second, units));
}
return (err_sum / (double)pairs.size());
}
private double calcRMSEx(List<Pair<Molecule, Molecule>> pairs, Units units){
double err_sum = 0;
for(Pair<Molecule,Molecule> pair : pairs) {
err_sum += abs(pair.first.getX(units) - pair.second.getX(units));
}
return (err_sum / (double)pairs.size());
}
private double calcRMSEy(List<Pair<Molecule, Molecule>> pairs, Units units){
double err_sum = 0;
for(Pair<Molecule,Molecule> pair : pairs) {
err_sum += abs(pair.first.getY(units) - pair.second.getY(units));
}
return (err_sum / (double)pairs.size());
}
private synchronized void processingNewFrame(String message) {
IJ.showStatus(String.format(message, processingFrame, frames));
IJ.showProgress((double)(processingFrame) / (double)frames);
processingFrame++;
}
private int getFrameCount() {
return (int) max(VectorMath.max(IJResultsTable.getResultsTable().getColumnAsDoubles(MoleculeDescriptor.LABEL_FRAME)),
VectorMath.max(IJGroundTruthTable.getGroundTruthTable().getColumnAsDoubles(MoleculeDescriptor.LABEL_FRAME)));
}
// ------------------------------------------------------------------------
final class MoleculeMatcherWorker implements Runnable {
// <Frame #, List of Molecules>
public Map<Integer, List<Molecule>> detections;
public Map<Integer, List<Molecule>> groundTruth;
private final MoleculeMatcher matcher;
public SortedSet<Integer> frames;
public Vector<Pair<Molecule, Molecule>> TP; // <ground-truth, detection>
public Vector<Molecule> FP, FN;
public MoleculeMatcherWorker(int frameStart, int frameStop, boolean threeD, double dist2Thr, Units distUnits) {
this.frames = new TreeSet<Integer>();
this.detections = fillWithData(frameStart, frameStop, IJResultsTable.getResultsTable());
this.groundTruth = fillWithData(frameStart, frameStop, IJGroundTruthTable.getGroundTruthTable());
this.matcher = new MoleculeMatcher(threeD, dist2Thr, distUnits);
//
this.TP = new Vector<Pair<Molecule, Molecule>>();
this.FP = new Vector<Molecule>();
this.FN = new Vector<Molecule>();
}
@Override
public void run() {
for(Integer frame : frames) {
if(Thread.interrupted()) {
return;
}
processingNewFrame("ThunderSTORM is evaluating frame %d out of %d...");
matcher.matchMolecules(detections.get(frame), groundTruth.get(frame), TP, FP, FN);
}
}
private Map<Integer, List<Molecule>> fillWithData(int frameStart, int frameStop, GenericTable table) {
Map<Integer, List<Molecule>> framesMolList = new HashMap<Integer, List<Molecule>>();
for(int i = 0, im = table.getRowCount(); i < im; i++) {
Molecule mol = table.getRow(i);
int frame = -1;
if (frameStart >= 0 && frameStop >= 0) {
frame = (int) mol.getParam(MoleculeDescriptor.LABEL_FRAME);
}
if((frame < frameStart) || (frame > frameStop)) continue;
if(!framesMolList.containsKey(frame)) {
framesMolList.put(frame, new Vector<Molecule>());
frames.add(frame);
}
framesMolList.get(frame).add(mol);
}
return framesMolList;
}
}
//---------------GUI-----------------------
class PerformanceEvaluationDialog extends DialogStub {
ParameterKey.Double toleranceRadius;
ParameterKey.String evaluationSpace;
ParameterKey.Boolean frameByFrame;
public PerformanceEvaluationDialog() {
super(new ParameterTracker("thunderstorm.evaluation"), IJ.getInstance(), "ThunderSTORM: Performance evaluation");
toleranceRadius = params.createDoubleField("toleranceRadius", DoubleValidatorFactory.positiveNonZero(), 50.0);
evaluationSpace = params.createStringField("evaluationSpace", null, "xy");
frameByFrame = params.createBooleanField("framebyFrame", null, true);
}
public ParameterTracker getParams() {
return params;
}
public double getToleranceRadius() {
return toleranceRadius.getValue();
}
public String getEvaluationSpace() {
return evaluationSpace.getValue();
}
public boolean getFrameByFrame() {
return frameByFrame.getValue();
}
@Override
protected void layoutComponents() {
JTextField toleranceTextField = new JTextField(20);
toleranceRadius.registerComponent(toleranceTextField);
add(new JLabel("Pair molecules with tolerance in:"), GridBagHelper.leftCol());
ButtonGroup btnGroup = new ButtonGroup();
JRadioButton rbXY = new JRadioButton("xy");
JRadioButton rbXYZ = new JRadioButton("xyz");
btnGroup.add(rbXY);
btnGroup.add(rbXYZ);
params.registerComponent(evaluationSpace, btnGroup);
add(rbXY, GridBagHelper.rightCol());
add(Box.createHorizontalGlue(), GridBagHelper.leftCol());
add(rbXYZ, GridBagHelper.rightCol());
add(new JLabel("Tolerance radius [nm]:"), GridBagHelper.leftCol());
add(toleranceTextField, GridBagHelper.rightCol());
add(Box.createVerticalStrut(10), GridBagHelper.twoCols());
JCheckBox frameByFrameCheckbox = new JCheckBox("frame-by-frame evaluation");
params.registerComponent(frameByFrame, frameByFrameCheckbox);
add(Box.createHorizontalGlue(), GridBagHelper.leftCol());
add(frameByFrameCheckbox, GridBagHelper.rightCol());
JPanel buttons = new JPanel(new GridBagLayout());
buttons.add(createDefaultsButton());
buttons.add(Box.createHorizontalGlue(), new GridBagHelper.Builder()
.fill(GridBagConstraints.HORIZONTAL).weightx(1).build());
buttons.add(Help.createHelpButton(PerformanceEvaluationPlugIn.class));
buttons.add(createOKButton());
buttons.add(createCancelButton());
add(Box.createVerticalStrut(10), GridBagHelper.twoCols());
add(buttons, GridBagHelper.twoCols());
params.loadPrefs();
getRootPane().setBorder(BorderFactory.createEmptyBorder(10, 10, 10, 10));
setLocationRelativeTo(null);
setModal(true);
}
}
}