/*-
* #%L
* Fiji distribution of ImageJ for the life sciences.
* %%
* Copyright (C) 2007 - 2017 Fiji developers.
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 2 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program. If not, see
* <http://www.gnu.org/licenses/gpl-2.0.html>.
* #L%
*/
package spim.process.fusion.deconvolution;
import ij.ImagePlus;
import java.io.File;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import mpicbg.spim.data.sequence.Angle;
import mpicbg.spim.data.sequence.Channel;
import mpicbg.spim.data.sequence.Illumination;
import mpicbg.spim.data.sequence.TimePoint;
import mpicbg.spim.data.sequence.ViewDescription;
import mpicbg.spim.data.sequence.ViewId;
import mpicbg.spim.io.IOFunctions;
import net.imglib2.Cursor;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.img.ImgFactory;
import net.imglib2.img.display.imagej.ImageJFunctions;
import net.imglib2.realtransform.AffineTransform3D;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.view.Views;
import spim.Threads;
import spim.fiji.ImgLib2Temp.Pair;
import spim.fiji.spimdata.SpimData2;
import spim.fiji.spimdata.ViewSetupUtils;
import spim.fiji.spimdata.imgloaders.LegacyStackImgLoaderIJ;
import spim.fiji.spimdata.interestpoints.CorrespondingInterestPoints;
import spim.fiji.spimdata.interestpoints.InterestPoint;
import spim.fiji.spimdata.interestpoints.InterestPointList;
import spim.process.fusion.FusionHelper;
import spim.process.fusion.ImagePortion;
import spim.process.fusion.boundingbox.BoundingBoxGUI;
import spim.process.fusion.export.DisplayImage;
import spim.process.fusion.weightedavg.ProcessFusion;
import spim.process.fusion.weights.Blending;
import spim.process.fusion.weights.NormalizingRandomAccessibleInterval;
import spim.process.fusion.weights.TransformedRealRandomAccessibleInterval;
import bdv.util.ConstantRandomAccessible;
/**
 * Prepares the input of a multi-view deconvolution for one timepoint and channel: every view is
 * transformed into the bounding box, a weight image is set up for it according to the selected
 * {@link WeightType}, the weights are normalized across views, and the PSFs are extracted from
 * corresponding beads, loaded from files, or taken over from another channel.
 *
 * @author Stephan Preibisch (stephan.preibisch@gmx.de)
 *
 */
public class ProcessForDeconvolution
{
	/**
	 * How the per-view weights are set up by {@link ProcessForDeconvolution#fuseStacksAndGetPSFs}.
	 * <ul>
	 * <li>WEIGHTS_ONLY: only the weights are computed (and displayed); transformed images are not kept</li>
	 * <li>NO_WEIGHTS: every view gets a constant weight of 1</li>
	 * <li>VIRTUAL_WEIGHTS: blending weights are wrapped as a transformed view instead of being stored in an image</li>
	 * <li>PRECOMPUTED_WEIGHTS: blending weights are computed into an image while the input is transformed</li>
	 * <li>LOAD_WEIGHTS: weights are loaded from the image files in {@link ProcessForDeconvolution#files}</li>
	 * </ul>
	 */
	public enum WeightType { WEIGHTS_ONLY, NO_WEIGHTS, VIRTUAL_WEIGHTS, PRECOMPUTED_WEIGHTS, LOAD_WEIGHTS }

	final protected SpimData2 spimData;
	final protected List< ViewId > viewIdsToProcess;
	final BoundingBoxGUI bb;
	final int[] blendingBorder;
	final int[] blendingRange;

	// results of the last call to fuseStacksAndGetPSFs()
	int minOverlappingViews;
	double avgOverlappingViews;
	ArrayList< ViewDescription > viewDescriptions;
	HashMap< ViewId, RandomAccessibleInterval< FloatType > > imgs, weights;
	ExtractPSF< FloatType > ePSF;

	// TODO: hack - if non-null, weights are loaded from these files (index i = i-th view of the fused stack)
	public static String[] files;

	// if true, every loaded weight image is shown for debugging
	public static boolean debugImport = false;

	/**
	 * @param spimData the dataset
	 * @param viewIdsToProcess all views that may be fused
	 * @param bb the bounding box into which all views are transformed
	 * @param blendingBorder per-dimension blending border (only the first 3 entries are used)
	 * @param blendingRange per-dimension blending range (only the first 3 entries are used)
	 */
	public ProcessForDeconvolution(
			final SpimData2 spimData,
			final List< ViewId > viewIdsToProcess,
			final BoundingBoxGUI bb,
			final int[] blendingBorder,
			final int[] blendingRange )
	{
		this.spimData = spimData;
		this.viewIdsToProcess = viewIdsToProcess;
		this.bb = bb;
		this.blendingBorder = blendingBorder;
		this.blendingRange = blendingRange;
	}

	public ExtractPSF< FloatType > getExtractPSF() { return ePSF; }
	public HashMap< ViewId, RandomAccessibleInterval< FloatType > > getTransformedImgs() { return imgs; }
	public HashMap< ViewId, RandomAccessibleInterval< FloatType > > getTransformedWeights() { return weights; }
	public ArrayList< ViewDescription > getViewDescriptions() { return viewDescriptions; }
	public int getMinOverlappingViews() { return minOverlappingViews; }
	public double getAvgOverlappingViews() { return avgOverlappingViews; }

	/**
	 * Fuses one stack, i.e. all angles/illuminations for one timepoint and channel: transforms every
	 * view into the bounding box, sets up and normalizes its weights, and extracts, loads or assigns
	 * the PSFs. Results are available through the getters of this class.
	 *
	 * @param timepoint the timepoint to fuse
	 * @param channel the channel to fuse
	 * @param imgFactory factory for the transformed and weight images
	 * @param osemIndex 1 = use the minimal, 2 = use the average number of overlapping views as OSEM
	 *        speedup; any other value uses {@code osemspeedup} as given
	 * @param osemspeedup the OSEM speedup to use if {@code osemIndex} is neither 1 nor 2
	 * @param weightType how to compute the weights (forced to LOAD_WEIGHTS if {@link #files} is set)
	 * @param extractPSFLabels per channel, the interest point label to extract PSFs from, or the
	 *        channel to take the PSFs from; receives the resulting {@link ExtractPSF} instance
	 * @param psfSize size of the PSFs to extract
	 * @param psfFiles per channel, PSF files to load (used if no PSFs are extracted), or null
	 * @param transformLoadedPSFs whether loaded PSFs are transformed with the view registrations
	 * @return true on success; false if there are no views, no PSFs are available, weights could
	 *         not be loaded, a processing task failed/was interrupted, or normalization failed
	 */
	public boolean fuseStacksAndGetPSFs(
			final TimePoint timepoint,
			final Channel channel,
			final ImgFactory< FloatType > imgFactory,
			final int osemIndex,
			double osemspeedup,
			WeightType weightType,
			final HashMap< Channel, ChannelPSF > extractPSFLabels,
			final long[] psfSize,
			final HashMap< Channel, ArrayList< Pair< Pair< Angle, Illumination >, String > > > psfFiles,
			final boolean transformLoadedPSFs )
	{
		// TODO: get rid of this hack
		if ( files != null )
		{
			weightType = WeightType.LOAD_WEIGHTS;
			IOFunctions.println( "WARNING: LOADING WEIGHTS FROM IMAGES, files.length()=" + files.length );
		}

		// get all views that are fused for this timepoint & channel
		this.viewDescriptions = FusionHelper.assembleInputData( spimData, timepoint, channel, viewIdsToProcess );

		if ( this.viewDescriptions.size() == 0 )
			return false;

		// LOAD_WEIGHTS reads the weight image of the i-th view from files[ i ]
		if ( weightType == WeightType.LOAD_WEIGHTS && ( files == null || files.length < viewDescriptions.size() ) )
		{
			IOFunctions.println( "Cannot load weights: " + viewDescriptions.size() + " weight image file(s) required, but " +
					( files == null ? 0 : files.length ) + " given." );
			return false;
		}

		this.imgs = new HashMap< ViewId, RandomAccessibleInterval< FloatType > >();
		this.weights = new HashMap< ViewId, RandomAccessibleInterval< FloatType > >();

		// WEIGHTS_ONLY: all views write into this one shared image, which is displayed as
		// "Number of views per pixel" at the end
		final Img< FloatType > overlapImg;

		if ( weightType == WeightType.WEIGHTS_ONLY )
			overlapImg = imgFactory.create( bb.getDimensions(), new FloatType() );
		else
			overlapImg = null;

		// the map may be null or lack an entry for this channel
		final ChannelPSF channelPSF = ( extractPSFLabels == null ) ? null : extractPSFLabels.get( channel );
		final boolean extractPSFs = ( channelPSF != null ) && ( channelPSF.getLabel() != null );
		final boolean loadPSFs = ( psfFiles != null );

		if ( extractPSFs )
			ePSF = new ExtractPSF< FloatType >();
		else if ( loadPSFs )
			ePSF = loadPSFs( channel, viewDescriptions, psfFiles, transformLoadedPSFs );
		else
			ePSF = assignOtherChannel( channel, extractPSFLabels );

		if ( ePSF == null )
			return false;

		// remember the extracted or loaded PSFs
		if ( channelPSF != null )
			channelPSF.setExtractPSFInstance( ePSF );

		for ( int i = 0; i < viewDescriptions.size(); ++i )
		{
			final ViewDescription vd = viewDescriptions.get( i );

			IOFunctions.println( "Transforming view " + i + " of " + (viewDescriptions.size()-1) + " (viewsetup=" + vd.getViewSetupId() + ", tp=" + vd.getTimePointId() + ")" );
			printWithTime( "Reserving memory for transformed & weight image." );

			// creating the output (WEIGHTS_ONLY writes into the shared overlapImg)
			final RandomAccessibleInterval< FloatType > transformedImg;

			if ( weightType == WeightType.WEIGHTS_ONLY )
				transformedImg = overlapImg;
			else
				transformedImg = imgFactory.create( bb.getDimensions(), new FloatType() );

			printWithTime( "Transformed image factory: " + imgFactory.getClass().getSimpleName() );

			// loading the input if necessary (WEIGHTS_ONLY needs it only for PSF extraction)
			final RandomAccessibleInterval< FloatType > img;

			if ( weightType == WeightType.WEIGHTS_ONLY && !extractPSFs )
			{
				img = null;
			}
			else
			{
				printWithTime( "Loading image." );
				img = ProcessFusion.getImage( new FloatType(), spimData, vd, true );

				if ( Img.class.isInstance( img ) )
					printWithTime( "Input image factory: " + ((Img< FloatType >)img).factory().getClass().getSimpleName() );
			}

			// initializing weights
			printWithTime( "Initializing transformation & weights: " + weightType.name() );

			spimData.getViewRegistrations().getViewRegistration( vd ).updateModel();
			final AffineTransform3D transform = spimData.getViewRegistrations().getViewRegistration( vd ).getModel();
			final long[] offset = new long[]{ bb.min( 0 ), bb.min( 1 ), bb.min( 2 ) };

			final RandomAccessibleInterval< FloatType > weightImg =
					createWeightImg( i, vd, weightType, imgFactory, img, transformedImg, transform, offset );

			if ( weightImg == null )
				return false;

			// split up into many parts for multithreading
			final Vector< ImagePortion > portions = FusionHelper.divideIntoPortions( Views.iterable( transformedImg ).size(), Threads.numThreads() * 4 );
			final ArrayList< Callable< String > > tasks = new ArrayList< Callable< String > >();

			printWithTime( "Transforming image & computing weights." );

			// the size of the input view is the same for every portion, so query (or load) it only once
			final Interval imgInterval;

			if ( weightType == WeightType.WEIGHTS_ONLY )
				imgInterval = new FinalInterval( ViewSetupUtils.getSizeOrLoad( vd.getViewSetup(), vd.getTimePoint(), spimData.getSequenceDescription().getImgLoader() ) );
			else
				imgInterval = null;

			for ( final ImagePortion portion : portions )
			{
				if ( weightType == WeightType.WEIGHTS_ONLY )
				{
					final Blending blending = getBlending( imgInterval, blendingBorder, blendingRange, vd );
					tasks.add( new TransformWeights( portion, imgInterval, blending, transform, overlapImg, weightImg, offset ) );
				}
				else if ( weightType == WeightType.PRECOMPUTED_WEIGHTS )
				{
					final Blending blending = getBlending( img, blendingBorder, blendingRange, vd );
					tasks.add( new TransformInputAndWeights( portion, img, blending, transform, transformedImg, weightImg, offset ) );
				}
				else if ( weightType == WeightType.NO_WEIGHTS || weightType == WeightType.VIRTUAL_WEIGHTS || weightType == WeightType.LOAD_WEIGHTS )
				{
					tasks.add( new TransformInput( portion, img, transform, transformedImg, offset ) );
				}
				else
				{
					throw new RuntimeException( weightType.name() + " not implemented yet." );
				}
			}

			if ( !runTasks( tasks, "Failed to compute fusion: " ) )
				return false;

			// extract PSFs if wanted
			if ( extractPSFs )
			{
				final String label = channelPSF.getLabel();
				final ArrayList< double[] > llist = getLocationsOfCorrespondingBeads( timepoint, vd, label );

				printWithTime( "Extracting PSF for viewsetup " + vd.getViewSetupId() +
						" using label '" + label + "'" + " (" + llist.size() + " corresponding detections available)" );

				ePSF.extractNextImg( img, vd, transform, llist, psfSize );
			}

			// WEIGHTS_ONLY keeps no transformed image, only the weights
			if ( weightType != WeightType.WEIGHTS_ONLY )
				imgs.put( vd, transformedImg );

			weights.put( vd, weightImg );

			// give the garbage collector a chance to free the temporarily loaded input image
			System.gc();
		}

		// normalize the weights (kept in view order so they can be mapped back below)
		final ArrayList< RandomAccessibleInterval< FloatType > > weightsSorted = new ArrayList< RandomAccessibleInterval< FloatType > >();

		for ( final ViewDescription vd : viewDescriptions )
			weightsSorted.add( weights.get( vd ) );

		printWithTime( "Computing weight normalization for deconvolution." );

		final WeightNormalizer wn;

		if ( weightType == WeightType.WEIGHTS_ONLY || weightType == WeightType.PRECOMPUTED_WEIGHTS || weightType == WeightType.LOAD_WEIGHTS )
			wn = new WeightNormalizer( weightsSorted );
		else if ( weightType == WeightType.VIRTUAL_WEIGHTS )
			wn = new WeightNormalizer( weightsSorted, imgFactory );
		else // NO_WEIGHTS: constant weights are not normalized
			wn = null;

		if ( wn != null && !wn.process() )
			return false;

		// put the potentially modified weights back; for VIRTUAL_WEIGHTS the entries are replaced
		// (adjustForOSEM() later expects NormalizingRandomAccessibleIntervals)
		for ( int i = 0; i < viewDescriptions.size(); ++i )
			weights.put( viewDescriptions.get( i ), weightsSorted.get( i ) );

		if ( wn != null )
		{
			final int minOverlap = wn.getMinOverlappingViews();
			final double avgOverlap = wn.getAvgOverlappingViews();

			// never use less than one overlapping view
			this.minOverlappingViews = Math.max( 1, minOverlap );
			this.avgOverlappingViews = Math.max( 1, avgOverlap );

			printWithTime( "Minimal number of overlapping views: " + minOverlap + ", using " + this.minOverlappingViews );
			printWithTime( "Average number of overlapping views: " + avgOverlap + ", using " + this.avgOverlappingViews );
		}
		else
		{
			// NO_WEIGHTS: every view has weight 1 over the whole bounding box, so every pixel is
			// covered by all views. Without this, the counts would be 0 (or stale from a previous
			// call), and an OSEM speedup derived from them (osemIndex 1/2) would zero all weights.
			this.minOverlappingViews = viewDescriptions.size();
			this.avgOverlappingViews = viewDescriptions.size();
		}

		if ( osemIndex == 1 )
			osemspeedup = getMinOverlappingViews();
		else if ( osemIndex == 2 )
			osemspeedup = getAvgOverlappingViews();

		printWithTime( "Adjusting for OSEM speedup = " + osemspeedup );

		if ( weightType == WeightType.WEIGHTS_ONLY )
			displayWeights( osemspeedup, weightsSorted, overlapImg, imgFactory );
		else
			adjustForOSEM( weights, weightType, osemspeedup );

		printWithTime( "Finished precomputations for deconvolution." );

		return true;
	}

	/**
	 * Creates the weight image of one view according to the weight type.
	 *
	 * @param i index of the view within {@link #viewDescriptions} (selects the file for LOAD_WEIGHTS)
	 * @param vd the view
	 * @param weightType how to compute the weights
	 * @param imgFactory factory for stored weight images
	 * @param img the loaded input image (null for WEIGHTS_ONLY without PSF extraction)
	 * @param transformedImg the output image of this view (defines the interval of the weights)
	 * @param transform the registration model of the view
	 * @param offset the minimum of the bounding box
	 * @return the weight image, or null if the weight image could not be loaded
	 */
	private RandomAccessibleInterval< FloatType > createWeightImg(
			final int i,
			final ViewDescription vd,
			final WeightType weightType,
			final ImgFactory< FloatType > imgFactory,
			final RandomAccessibleInterval< FloatType > img,
			final RandomAccessibleInterval< FloatType > transformedImg,
			final AffineTransform3D transform,
			final long[] offset )
	{
		if ( weightType == WeightType.PRECOMPUTED_WEIGHTS || weightType == WeightType.WEIGHTS_ONLY )
			return imgFactory.create( bb.getDimensions(), new FloatType() ); // filled by the transform tasks

		if ( weightType == WeightType.NO_WEIGHTS )
			return Views.interval( new ConstantRandomAccessible< FloatType >( new FloatType( 1 ), transformedImg.numDimensions() ), transformedImg );

		if ( weightType == WeightType.VIRTUAL_WEIGHTS )
		{
			final Blending blending = getBlending( img, blendingBorder, blendingRange, vd );
			return new TransformedRealRandomAccessibleInterval< FloatType >( blending, new FloatType(), transformedImg, transform, offset );
		}

		// LOAD_WEIGHTS (files was checked to contain an entry for every view)
		final File file = new File( files[ i ] );
		IOFunctions.println( "WARNING: LOADING WEIGHTS FROM: '" + file + "'" );
		ImagePlus imp = LegacyStackImgLoaderIJ.open( file );

		// NOTE(review): assumes open() signals an unreadable file by returning null
		if ( imp == null )
		{
			IOFunctions.println( "Failed to open weight image '" + file + "'" );
			return null;
		}

		final Img< FloatType > loaded = imgFactory.create( bb.getDimensions(), new FloatType() );
		LegacyStackImgLoaderIJ.imagePlus2ImgLib2Img( imp, loaded, false );
		imp.close();

		if ( debugImport )
		{
			imp = ImageJFunctions.show( loaded );
			imp.setTitle( "ViewSetup " + vd.getViewSetupId() + " Timepoint " + vd.getTimePointId() );
		}

		return loaded;
	}

	/**
	 * Runs all tasks on a fixed thread pool and waits for them. The pool is always shut down.
	 *
	 * @param tasks the tasks to run
	 * @param errorMessage prefix of the message printed on failure
	 * @return true if all tasks completed normally, false if one failed or the wait was interrupted
	 */
	private static boolean runTasks( final List< Callable< String > > tasks, final String errorMessage )
	{
		final ExecutorService taskExecutor = Executors.newFixedThreadPool( Threads.numThreads() );

		try
		{
			// invokeAll() returns when all tasks are complete; get() rethrows a task's exception
			for ( final Future< String > future : taskExecutor.invokeAll( tasks ) )
				future.get();

			return true;
		}
		catch ( final InterruptedException e )
		{
			Thread.currentThread().interrupt();
			IOFunctions.println( errorMessage + e );
			e.printStackTrace();
			return false;
		}
		catch ( final ExecutionException e )
		{
			IOFunctions.println( errorMessage + e );
			e.printStackTrace();
			return false;
		}
		finally
		{
			taskExecutor.shutdown();
		}
	}

	/** Prints a message prefixed with the current date and time in parentheses. */
	private static void printWithTime( final String message )
	{
		IOFunctions.println( "(" + new Date( System.currentTimeMillis() ) + "): " + message );
	}

	/**
	 * Scales every weight by the OSEM speedup, clamping each individual weight at 1.
	 * A speedup of exactly 1 leaves the weights untouched.
	 *
	 * @param weights the (normalized) weights of all views
	 * @param weightType how the weights were created
	 * @param osemspeedup the factor to apply
	 * @throws RuntimeException for weight types not handled here
	 */
	private static void adjustForOSEM( final HashMap< ViewId, RandomAccessibleInterval< FloatType > > weights, final WeightType weightType, final double osemspeedup )
	{
		if ( osemspeedup == 1.0 )
			return;

		if ( weightType == WeightType.PRECOMPUTED_WEIGHTS || weightType == WeightType.WEIGHTS_ONLY || weightType == WeightType.LOAD_WEIGHTS )
		{
			for ( final RandomAccessibleInterval< FloatType > w : weights.values() )
			{
				for ( final FloatType f : Views.iterable( w ) )
					f.set( Math.min( 1, f.get() * (float)osemspeedup ) ); // individual contribution never higher than 1
			}
		}
		else if ( weightType == WeightType.NO_WEIGHTS )
		{
			// the weights are views of a ConstantRandomAccessible, so only one position is updated;
			// NOTE(review): relies on the constant value being shared by all positions
			for ( final RandomAccessibleInterval< FloatType > w : weights.values() )
			{
				final RandomAccess< FloatType > r = w.randomAccess();
				final long[] min = new long[ w.numDimensions() ];
				w.min( min );
				r.setPosition( min );
				r.get().set( Math.min( 1, r.get().get() * (float)osemspeedup ) ); // individual contribution never higher than 1
			}
		}
		else if ( weightType == WeightType.VIRTUAL_WEIGHTS )
		{
			for ( final RandomAccessibleInterval< FloatType > w : weights.values() )
				((NormalizingRandomAccessibleInterval< FloatType >) w).setOSEMspeedup( osemspeedup );
		}
		else
		{
			throw new RuntimeException( "Weight Type: " + weightType.name() + " not supported in ProcessForDeconvolution.adjustForOSEM()" );
		}
	}

	/**
	 * Loads the PSF files of one channel, optionally transforming them by the (freshly updated)
	 * registration model of each view.
	 */
	private ExtractPSF<FloatType> loadPSFs(
			final Channel ch,
			final ArrayList< ViewDescription > allInputData,
			final HashMap< Channel, ArrayList< Pair< Pair< Angle, Illumination >, String > > > psfFiles,
			final boolean transformLoadedPSFs )
	{
		final HashMap< ViewId, AffineTransform3D > models;

		if ( transformLoadedPSFs )
		{
			models = new HashMap< ViewId, AffineTransform3D >();

			for ( final ViewDescription viewDesc : allInputData )
			{
				// make sure the model reflects the current list of transformations
				spimData.getViewRegistrations().getViewRegistration( viewDesc ).updateModel();
				models.put( viewDesc, spimData.getViewRegistrations().getViewRegistration( viewDesc ).getModel() );
			}
		}
		else
		{
			models = null;
		}

		return ExtractPSF.loadAndTransformPSFs( psfFiles.get( ch ), allInputData, new FloatType(), models );
	}

	/**
	 * Lets the given channel use the PSFs of the channel configured via
	 * {@code ChannelPSF.getOtherChannel()}. For every view of this channel (in
	 * {@link #viewDescriptions}), every view in {@link #viewIdsToProcess} of the other channel with
	 * the same angle, illumination and timepoint is registered as its PSF source in the other
	 * channel's {@link ExtractPSF} instance, which is then returned.
	 *
	 * @return the other channel's ExtractPSF instance, or null if there is no PSF configuration for
	 *         this channel or the other channel has no PSFs (yet)
	 */
	protected ExtractPSF< FloatType > assignOtherChannel( final Channel channel, final HashMap< Channel, ChannelPSF > extractPSFLabels )
	{
		final ChannelPSF thisChannelPSF = ( extractPSFLabels == null ) ? null : extractPSFLabels.get( channel );

		if ( thisChannelPSF == null )
		{
			IOFunctions.println( "No PSF assignment defined for channel " + channel.getName() );
			return null;
		}

		final Channel otherChannel = thisChannelPSF.getOtherChannel();
		final ChannelPSF otherChannelPSF = extractPSFLabels.get( otherChannel );
		final ExtractPSF< FloatType > otherEPSF = ( otherChannelPSF == null ) ? null : otherChannelPSF.getExtractPSFInstance();

		// the mappings must go into the instance that is returned, not into the current ePSF field
		if ( otherEPSF == null )
		{
			IOFunctions.println( "No PSFs available from the channel assigned to channel " + channel.getName() );
			return null;
		}

		for ( final ViewDescription sourceVD : viewDescriptions )
		{
			// search the view to map to
			for ( final ViewId viewId : viewIdsToProcess )
			{
				final ViewDescription otherVD = spimData.getSequenceDescription().getViewDescription( viewId );

				if (
					otherVD.getViewSetup().getAngle().getId() == sourceVD.getViewSetup().getAngle().getId() &&
					otherVD.getViewSetup().getIllumination().getId() == sourceVD.getViewSetup().getIllumination().getId() &&
					otherVD.getTimePointId() == sourceVD.getTimePointId() &&
					otherVD.getViewSetup().getChannel().getId() == otherChannel.getId() )
				{
					otherEPSF.getViewIdMapping().put( sourceVD, otherVD );

					IOFunctions.println(
							"ViewID=" + sourceVD.getViewSetupId() + ", TPID=" + sourceVD.getTimePointId() +
							" takes the PSF from " +
							"ViewID=" + otherVD.getViewSetupId() + ", TPID=" + otherVD.getTimePointId() );
				}
			}
		}

		return otherEPSF;
	}

	/**
	 * Collects the locations of all interest points with the given label in one view that have at
	 * least one correspondence.
	 *
	 * @param tp unused
	 * @param inputData the view
	 * @param label the interest point label
	 * @return copies of the locations, each detection at most once
	 */
	protected ArrayList< double[] > getLocationsOfCorrespondingBeads( final TimePoint tp, final ViewDescription inputData, final String label )
	{
		// NOTE(review): fails with a NullPointerException if the view has no points for this label
		final InterestPointList iplist = spimData.getViewInterestPoints().getViewInterestPointLists( inputData ).getInterestPointList( label );

		// we use a hashset as a detection can correspond with several other detections, and we only want it once
		final HashSet< Integer > ipWithCorrespondences = new HashSet< Integer >();

		for ( final CorrespondingInterestPoints cip : iplist.getCorrespondingInterestPoints() )
			ipWithCorrespondences.add( cip.getDetectionId() );

		final ArrayList< double[] > llist = new ArrayList< double[] >();

		// now go over all detections and see if they had correspondences
		for ( final InterestPoint ip : iplist.getInterestPoints() )
			if ( ipWithCorrespondences.contains( ip.getId() ) )
				llist.add( ip.getL().clone() );

		return llist;
	}

	/**
	 * Displays the number of views per pixel, the sum of weights per pixel, and the sum of weights
	 * per pixel after applying the OSEM speedup (each weight clamped at 1).
	 *
	 * @param osemspeedup the OSEM speedup
	 * @param weights the weights of all views (same size and iteration order as overlapImg)
	 * @param overlapImg the number of views per pixel
	 * @param imgFactory factory for the sum images
	 */
	protected void displayWeights(
			final double osemspeedup,
			final ArrayList< RandomAccessibleInterval< FloatType > > weights,
			final RandomAccessibleInterval< FloatType > overlapImg,
			final ImgFactory< FloatType > imgFactory )
	{
		final DisplayImage d = new DisplayImage();

		d.exportImage( overlapImg, bb, "Number of views per pixel" );

		final Img< FloatType > w = imgFactory.create( overlapImg, new FloatType() );
		final Img< FloatType > wosem = imgFactory.create( overlapImg, new FloatType() );

		// split up into many parts for multithreading
		final Vector< ImagePortion > portions = FusionHelper.divideIntoPortions( Views.iterable( weights.get( 0 ) ).size(), Threads.numThreads() * 2 );
		final ArrayList< Callable< String > > tasks = new ArrayList< Callable< String > >();

		for ( final ImagePortion portion : portions )
		{
			tasks.add( new Callable< String >()
			{
				@Override
				public String call() throws Exception
				{
					final ArrayList< Cursor< FloatType > > cursors = new ArrayList< Cursor< FloatType > >();
					final Cursor< FloatType > sum = w.cursor();
					final Cursor< FloatType > sumOsem = wosem.cursor();

					for ( final RandomAccessibleInterval< FloatType > imgW : weights )
					{
						final Cursor< FloatType > c = Views.iterable( imgW ).cursor();
						c.jumpFwd( portion.getStartPosition() );
						cursors.add( c );
					}

					sum.jumpFwd( portion.getStartPosition() );
					sumOsem.jumpFwd( portion.getStartPosition() );

					for ( long j = 0; j < portion.getLoopSize(); ++j )
					{
						double sumW = 0;
						double sumOsemW = 0;

						for ( final Cursor< FloatType > c : cursors )
						{
							final float weight = c.next().get();
							sumW += weight;
							sumOsemW += Math.min( 1, weight * osemspeedup );
						}

						sum.next().set( (float)sumW );
						sumOsem.next().set( (float)sumOsemW );
					}

					return "done.";
				}
			});
		}

		if ( !runTasks( tasks, "Failed to compute weight normalization for deconvolution: " ) )
			return;

		d.exportImage( w, bb, "Sum of weights per pixel" );
		d.exportImage( wosem, bb, "OSEM=" + osemspeedup + ", sum of weights per pixel" );
	}

	/**
	 * Creates the blending weight function for a view of the given interval, using the first three
	 * entries of the blending border and range.
	 */
	protected Blending getBlending( final Interval interval, final int[] blendingBorder, final int[] blendingRange, final ViewDescription desc )
	{
		final float[] blending = new float[ 3 ];
		final float[] border = new float[ 3 ];

		for ( int d = 0; d < 3; ++d )
		{
			blending[ d ] = blendingRange[ d ];
			border[ d ] = blendingBorder[ d ];
		}

		return new Blending( interval, border, blending );
	}
}