package cz.cuni.lf1.lge.ThunderSTORM.drift;
import cz.cuni.lf1.lge.ThunderSTORM.UI.GUI;
import cz.cuni.lf1.lge.ThunderSTORM.results.*;
import cz.cuni.lf1.lge.ThunderSTORM.estimators.PSF.Molecule;
import cz.cuni.lf1.lge.ThunderSTORM.estimators.PSF.MoleculeDescriptor;
import cz.cuni.lf1.lge.ThunderSTORM.estimators.PSF.PSFModel;
import cz.cuni.lf1.lge.ThunderSTORM.estimators.optimizers.NelderMead;
import cz.cuni.lf1.lge.ThunderSTORM.util.IJProgressTracker;
import cz.cuni.lf1.lge.ThunderSTORM.util.MathProxy;
import cz.cuni.lf1.lge.ThunderSTORM.util.VectorMath;
import ij.IJ;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction;
import org.apache.commons.math3.util.MathArrays;
/**
 * Estimates sample drift in x and y from fiducial markers, i.e. grouped
 * molecules that were detected in a large portion of the frame range.
 *
 * The detections of all fiducial markers are shifted by a per-marker offset,
 * merged into one data set ordered by frame, and smoothed/interpolated to
 * obtain the drift as a function of frame.
 */
public class FiducialDriftEstimator {

    /**
     * Estimates the drift trajectory from the supplied molecules.
     *
     * @param molecules localized molecules; must be non-empty and every
     * molecule must provide the frame, x and y parameters
     * @param distanceThr distance threshold used when grouping molecules
     * from subsequent frames (it is squared before being handed to the
     * grouping routine)
     * @param onTimeRatio a grouped molecule is taken as a fiducial marker
     * when its detection count is strictly greater than
     * {@code onTimeRatio * (maxFrame - minFrame)}
     * @param smoothingBandwidth bandwidth passed to the {@code ModifiedLoess}
     * interpolator
     * @return the drift functions for x and y, the combined (frame, x, y)
     * data they were fitted to, the frame range, and the units of the x
     * parameter of the first molecule
     * @throws IllegalArgumentException if {@code molecules} is null or empty
     * @throws RuntimeException if no grouped molecule qualifies as a fiducial
     * marker
     */
    public DriftResults estimateDrift(List<Molecule> molecules, double distanceThr, double onTimeRatio, double smoothingBandwidth) {
        //fail early with a meaningful message; otherwise the frame range lookup
        //below would fail with a bare IndexOutOfBoundsException / NullPointerException
        if(molecules == null || molecules.isEmpty()) {
            throw new IllegalArgumentException("Cannot estimate drift: the list of molecules is empty.");
        }
        int minFrame = (int) getMinFrame(molecules);
        int maxFrame = (int) getMaxFrame(molecules);

        //group molecules appearing in subsequent frames
        IJ.showStatus("Grouping molecules...");
        List<Molecule> groupedMolecules = groupMolecules(molecules, distanceThr);

        //select fiducial markers (molecules that are on for many frames)
        List<Molecule> fiducialMarkers = selectFiducialMarkers(groupedMolecules, onTimeRatio, minFrame, maxFrame);
        if(fiducialMarkers.isEmpty()) {
            throw new RuntimeException("No fiducial markers found.");
        }

        //combine data points from multiple fiducial markers
        int dataPoints = countDetections(fiducialMarkers);
        double[] combinedFrames = new double[dataPoints];
        double[] combinedX = new double[dataPoints];
        double[] combinedY = new double[dataPoints];

        //combine frame data values
        int lastIndex = 0;
        for(Molecule marker : fiducialMarkers) {
            List<Molecule> detections = marker.getDetections();
            int frameIndex = detections.get(0).descriptor.getParamIndex(MoleculeDescriptor.LABEL_FRAME);
            double[] frame = extractParamAsArray(detections, frameIndex);
            System.arraycopy(frame, 0, combinedFrames, lastIndex, frame.length);
            lastIndex += frame.length;
        }

        //find offsets for each fiducial marker (to get relative drift out of absolute coordinates)
        IJ.showStatus("Finding marker offsets (x)...");
        double[] markerOffsetsInX = findFiducialsOffsets(fiducialMarkers, combinedFrames, PSFModel.Params.LABEL_X);
        IJ.showProgress(0.875);
        GUI.checkIJEscapePressed();
        IJ.showStatus("Finding marker offsets (y)...");
        double[] markerOffsetsInY = findFiducialsOffsets(fiducialMarkers, combinedFrames, PSFModel.Params.LABEL_Y);
        IJ.showProgress(0.95);
        GUI.checkIJEscapePressed();

        //combine x,y, while subtracting the found offsets
        //markers are visited in the same order as for the frames above, so
        //the i-th entry of all three combined arrays describes the same detection
        lastIndex = 0;
        for(int i = 0; i < fiducialMarkers.size(); i++) {
            List<Molecule> detections = fiducialMarkers.get(i).getDetections();
            MoleculeDescriptor descriptor = detections.get(0).descriptor;
            double[] x = extractParamAsArray(detections, descriptor.getParamIndex(PSFModel.Params.LABEL_X));
            double[] y = extractParamAsArray(detections, descriptor.getParamIndex(PSFModel.Params.LABEL_Y));
            VectorMath.add(x, -markerOffsetsInX[i]);
            VectorMath.add(y, -markerOffsetsInY[i]);
            System.arraycopy(x, 0, combinedX, lastIndex, x.length);
            System.arraycopy(y, 0, combinedY, lastIndex, y.length);
            lastIndex += x.length;
        }

        //sort, because loess interpolation needs non descending domain values
        MathArrays.sortInPlace(combinedFrames, combinedX, combinedY);

        //subtract first frame coordinates so that drift at first frame is zero
        //Could be a problem when first frame drift is off. ??
        VectorMath.add(combinedX, -combinedX[0]);
        VectorMath.add(combinedY, -combinedY[0]);

        //smooth & interpolate
        IJ.showStatus("Smoothing and interpolating drift...");
        //NOTE(review): the meaning of the second constructor argument (0) is
        //not visible here - confirm against ModifiedLoess
        ModifiedLoess interpolator = new ModifiedLoess(smoothingBandwidth, 0);
        PolynomialSplineFunction xFunction = CorrelationDriftEstimator.addLinearExtrapolationToBorders(
                interpolator.interpolate(combinedFrames, combinedX), minFrame, maxFrame);
        PolynomialSplineFunction yFunction = CorrelationDriftEstimator.addLinearExtrapolationToBorders(
                interpolator.interpolate(combinedFrames, combinedY), minFrame, maxFrame);
        IJ.showProgress(1);

        //same units as input
        MoleculeDescriptor.Units units = molecules.get(0).getParamUnits(PSFModel.Params.LABEL_X);
        return new DriftResults(xFunction, yFunction, combinedFrames, combinedX, combinedY, minFrame, maxFrame, units);
    }

    /**
     * Picks the grouped molecules whose detection count is strictly greater
     * than {@code onTimeRatio * (maxFrame - minFrame)}.
     */
    private List<Molecule> selectFiducialMarkers(List<Molecule> groupedMolecules, double onTimeRatio, int minFrame, int maxFrame) {
        double detectionThreshold = onTimeRatio * (maxFrame - minFrame);
        List<Molecule> fiducialMarkers = new ArrayList<Molecule>();
        for(Molecule mol : groupedMolecules) {
            if(mol.getParam(MoleculeDescriptor.LABEL_DETECTIONS) > detectionThreshold) {
                fiducialMarkers.add(mol);
            }
        }
        return fiducialMarkers;
    }

    /**
     * Returns the total number of detections over all given fiducial markers.
     */
    private int countDetections(List<Molecule> fiducialMarkers) {
        int total = 0;
        for(Molecule marker : fiducialMarkers) {
            total += marker.getDetections().size();
        }
        return total;
    }

    /**
     * Inserts all molecules into a {@code FrameSequence}, lets it match
     * molecules using the squared distance threshold, and returns all
     * molecules of the sequence afterwards.
     */
    private List<Molecule> groupMolecules(List<Molecule> molecules, double distanceThr) {
        FrameSequence grouping = new FrameSequence();
        for(Molecule mol : molecules) {
            grouping.InsertMolecule(mol);
        }
        //NOTE(review): presumably reports grouping progress within the 0 - 0.8
        //part of the progress bar - confirm against IJProgressTracker
        IJProgressTracker tracker = new IJProgressTracker(0, 0.8);
        //NOTE(review): the arguments 2, 5 and 0 are parameters of the matching
        //strategy; their exact meaning is defined in FrameSequence
        grouping.matchMolecules(MathProxy.sqr(distanceThr),
                new FrameSequence.RelativeToDetectionCount(2),
                new FrameSequence.LastFewDetectionsMean(5),
                0,
                tracker);
        return grouping.getAllMolecules();
    }

    /**
     * Finds one offset per fiducial marker for the given coordinate, such
     * that after subtracting the offsets the markers agree with each other
     * as well as possible in the frames where two or more detections exist.
     *
     * The offsets are found with a Nelder-Mead search that minimizes the
     * square root of the summed squared deviations of the offset-corrected
     * values from their per-frame mean.
     *
     * @param fiducials fiducial markers; each must have at least one detection
     * @param combinedFrames not used by the current implementation
     * @param param name of the molecule parameter to align (e.g. x or y)
     * @return fitted offsets, indexed like {@code fiducials}
     */
    public double[] findFiducialsOffsets(List<Molecule> fiducials, double[] combinedFrames, String param) {
        //first, restructure the required data in a data structure that can be efficiently used in the optimization process
        //a helper class that holds a detection coordinate and an index of fiducial marker the detection belongs to
        class ValAndMarkerIndex {

            final double val;
            final int index;

            ValAndMarkerIndex(double val, int index) {
                this.val = val;
                this.index = index;
            }
        }

        //create a map from frame to a list of fiducial marker detections in that frame
        Map<Double, List<ValAndMarkerIndex>> values = new HashMap<Double, List<ValAndMarkerIndex>>();
        for(int i = 0; i < fiducials.size(); i++) {
            for(Molecule detection : fiducials.get(i).getDetections()) {
                double frame = detection.getParam(MoleculeDescriptor.LABEL_FRAME);
                List<ValAndMarkerIndex> list = values.get(frame);
                if(list == null) {
                    list = new ArrayList<ValAndMarkerIndex>();
                    values.put(frame, list);
                }
                list.add(new ValAndMarkerIndex(detection.getParam(param), i));
            }
        }

        //prune frames with less than two detections
        //(a single value always equals its own mean and would add nothing to the cost)
        for(Iterator<Map.Entry<Double, List<ValAndMarkerIndex>>> it = values.entrySet().iterator(); it.hasNext();) {
            if(it.next().getValue().size() < 2) {
                it.remove();
            }
        }

        //copy the values collection to a list
        //this is the final data structure used in optimization
        final List<List<ValAndMarkerIndex>> detectionsInFrames = new ArrayList<List<ValAndMarkerIndex>>(values.values());

        //cost function:
        //for each frame where multiple drift values are present
        //  cost += square of difference between each drift value and mean drift value for that frame
        MultivariateFunction fun = new MultivariateFunction() {
            @Override
            public double value(double[] point) {
                //called on every evaluation so that the search can be interrupted from the UI
                GUI.checkIJEscapePressed();
                double cost = 0;
                for(List<ValAndMarkerIndex> oneFrameDetections : detectionsInFrames) {
                    double mean = 0;
                    for(ValAndMarkerIndex detection : oneFrameDetections) {
                        mean += detection.val - point[detection.index];
                    }
                    mean /= oneFrameDetections.size();
                    for(ValAndMarkerIndex detection : oneFrameDetections) {
                        cost += MathProxy.sqr(detection.val - point[detection.index] - mean);
                    }
                }
                return Math.sqrt(cost);
            }
        };

        //values for first iteration: first detection coords
        double[] guess = new double[fiducials.size()];
        for(int i = 0; i < guess.length; i++) {
            guess[i] = fiducials.get(i).getDetections().get(0).getParam(param);
        }

        //first simplex step size, ?????
        final double initialSimplexStep = 50;
        double[] initialSimplex = new double[fiducials.size()];
        Arrays.fill(initialSimplex, initialSimplexStep);

        final int maxIter = 5000;
        //NOTE(review): 1e-8 and 10 are tuning arguments of the optimizer;
        //their exact meaning is defined by NelderMead.optimize - confirm there
        NelderMead nm = new NelderMead();
        nm.optimize(fun, NelderMead.Objective.MINIMIZE, guess, 1e-8, initialSimplex, 10, maxIter);
        return nm.xmin;
    }

    /**
     * Reads the parameter stored at the given index from every molecule.
     *
     * @param mols molecules to read from
     * @param index parameter index passed to {@code Molecule.getParamAt}
     * @return array with one value per molecule, in list order
     */
    static double[] extractParamAsArray(List<Molecule> mols, int index) {
        int count = mols.size();
        double[] arr = new double[count];
        for(int i = 0; i < count; i++) {
            arr[i] = mols.get(i).getParamAt(index);
        }
        return arr;
    }

    /**
     * Returns the smallest frame value among the molecules; the list must not be empty.
     */
    private double getMinFrame(List<Molecule> molecules) {
        double min = molecules.get(0).getParam(MoleculeDescriptor.LABEL_FRAME);
        for(Molecule mol : molecules) {
            double frame = mol.getParam(MoleculeDescriptor.LABEL_FRAME);
            if(frame < min) {
                min = frame;
            }
        }
        return min;
    }

    /**
     * Returns the largest frame value among the molecules; the list must not be empty.
     */
    private double getMaxFrame(List<Molecule> molecules) {
        double max = molecules.get(0).getParam(MoleculeDescriptor.LABEL_FRAME);
        for(Molecule mol : molecules) {
            double frame = mol.getParam(MoleculeDescriptor.LABEL_FRAME);
            if(frame > max) {
                max = frame;
            }
        }
        return max;
    }
}