/*-
* #%L
* Fiji distribution of ImageJ for the life sciences.
* %%
* Copyright (C) 2007 - 2017 Fiji developers.
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 2 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program. If not, see
* <http://www.gnu.org/licenses/gpl-2.0.html>.
* #L%
*/
package spim.process.fusion.deconvolution;
import java.io.File;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import ij.CompositeImage;
import ij.ImagePlus;
import ij.ImageStack;
import mpicbg.spim.io.IOFunctions;
import net.imglib2.Cursor;
import net.imglib2.Dimensions;
import net.imglib2.IterableInterval;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.exception.IncompatibleTypeException;
import net.imglib2.img.Img;
import net.imglib2.img.ImgFactory;
import net.imglib2.multithreading.SimpleMultiThreading;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.RealSum;
import net.imglib2.util.Util;
import net.imglib2.view.Views;
import spim.Threads;
import spim.fiji.ImgLib2Temp.Pair;
import spim.fiji.spimdata.imgloaders.LegacyStackImgLoaderIJ;
import spim.process.fusion.FusionHelper;
import spim.process.fusion.ImagePortion;
import spim.process.fusion.deconvolution.MVDeconFFT.PSFTYPE;
import spim.process.fusion.export.DisplayImage;
public class MVDeconvolution
{
// if you want to start from a certain iteration
public static String initialImage = null;
// check in advance if values are reasonable
public static boolean checkNumbers = true;
public static boolean debug = true;
public static int debugInterval = 1;
public static boolean setBackgroundToAvg = true;//false;
final static float minValue = 0.0001f;
final int numViews, numDimensions;
final double lambda;
ImageStack stack;
CompositeImage ci;
boolean collectStatistics = true;
// current iteration
int i = 0;
// the multi-view deconvolved image
Img< FloatType > psi;
// temporary images that are reused for computation
final Img< FloatType > tmp1, tmp2;
// the input data
final MVDeconInput views;
ArrayList< MVDeconFFT > data;
String name;
public MVDeconvolution(
final MVDeconInput views,
final PSFTYPE iterationType,
final int numIterations,
final double lambda,
double osemspeedup,
final int osemspeedupindex,
final String name ) throws IncompatibleTypeException
{
this.psi = null;
this.name = name;
this.data = views.getViews();
this.views = views;
this.numViews = data.size();
this.numDimensions = data.get( 0 ).getImage().numDimensions();
this.lambda = lambda;
IOFunctions.println( "(" + new Date(System.currentTimeMillis()) + "): Deconvolved & temporary image factory: " + views.imgFactory().getClass().getSimpleName() );
// init all views
views.init( iterationType );
if ( initialImage != null )
{
IOFunctions.println( "(" + new Date(System.currentTimeMillis()) + "): Loading intial image '" + initialImage + "'" );
this.psi = loadInitialImage(
initialImage,
checkNumbers,
minValue,
data.get( 0 ).getImage(),
views.imgFactory() );
}
else
{
// the real data image psi is initialized with the fused image
// if there was no initial guess loaded
IOFunctions.println( "(" + new Date(System.currentTimeMillis()) + "): Fusing image for first iteration" );
this.psi = views.imgFactory().create( data.get( 0 ).getImage(), new FloatType() );
double avg = fuseFirstIteration( psi, views.getViews() );
IOFunctions.println( "(" + new Date(System.currentTimeMillis()) + "): Average intensity in overlapping area: " + avg );
if ( Double.isNaN( avg ) )
{
avg = 0.5;
IOFunctions.println( "(" + new Date(System.currentTimeMillis()) + "): ERROR! Computing average FAILED, is NaN, setting it to: " + avg );
}
IOFunctions.println( "(" + new Date(System.currentTimeMillis()) + "): Setting image to average intensity: " + avg );
for ( final FloatType t : psi )
t.set( (float)avg );
}
//new DisplayImage().exportImage( psi, "psi" );
// instantiate the temporary images
this.tmp1 = views.imgFactory().create( psi, new FloatType() );
this.tmp2 = views.imgFactory().create( psi, new FloatType() );
// run the deconvolution
while ( i < numIterations )
{
// show the fused image first
if ( debug && (i-1) % debugInterval == 0 )
{
// if it is slices, wrap & copy otherwise virtual & copy - never use the actual image
// as it is being updated in the process
final ImagePlus tmp = DisplayImage.getImagePlusInstance( psi, true, "Psi", 0, 1 ).duplicate();
if ( this.stack == null )
{
this.stack = tmp.getImageStack();
for ( int i = 0; i < this.psi.dimension( 2 ); ++i )
this.stack.setSliceLabel( "Iteration 1", i + 1 );
tmp.setTitle( "debug view" );
this.ci = new CompositeImage( tmp, CompositeImage.COMPOSITE );
this.ci.setDimensions( 1, (int)this.psi.dimension( 2 ), 1 );
this.ci.show();
}
else if ( stack.getSize() == this.psi.dimension( 2 ) )
{
final ImageStack t = tmp.getImageStack();
for ( int i = 0; i < this.psi.dimension( 2 ); ++i )
this.stack.addSlice( "Iteration 2", t.getProcessor( i + 1 ) );
this.ci.hide();
this.ci = new CompositeImage( new ImagePlus( "debug view", this.stack ), CompositeImage.COMPOSITE );
this.ci.setDimensions( 1, (int)this.psi.dimension( 2 ), 2 );
this.ci.show();
}
else
{
final ImageStack t = tmp.getImageStack();
for ( int i = 0; i < this.psi.dimension( 2 ); ++i )
this.stack.addSlice( "Iteration " + i, t.getProcessor( i + 1 ) );
this.ci.setStack( this.stack, 1, (int)this.psi.dimension( 2 ), stack.getSize() / (int)this.psi.dimension( 2 ) );
}
}
runIteration();
}
IOFunctions.println( "Masking never updated pixels." );
fuseFirstIteration( tmp1, views.getViews() );
final Cursor< FloatType > tmp1c = tmp1.cursor();
for ( final FloatType t : psi )
if ( tmp1c.next().get() == 0 )
t.set( 0 );
IOFunctions.println( "DONE (" + new Date(System.currentTimeMillis()) + ")." );
}
protected static final double fuseFirstIteration( final Img< FloatType > psi, final ArrayList< MVDeconFFT > views )
{
final int nThreads = Threads.numThreads();
final int nPortions = nThreads * 2;
// split up into many parts for multithreading
final Vector< ImagePortion > portions = FusionHelper.divideIntoPortions( psi.size(), nPortions );
final ArrayList< Callable< Pair< RealSum, Long > > > tasks = new ArrayList< Callable< Pair< RealSum, Long > > >();
final ExecutorService taskExecutor = Executors.newFixedThreadPool( nThreads );
final ArrayList< RandomAccessibleInterval< FloatType > > imgs = new ArrayList< RandomAccessibleInterval< FloatType > >();
for ( final MVDeconFFT mvdecon : views )
imgs.add( mvdecon.getImage() );
for ( final ImagePortion portion : portions )
tasks.add( new FirstIteration( portion, psi, imgs ) );
final RealSum s = new RealSum();
long count = 0;
try
{
// invokeAll() returns when all tasks are complete
final List< Future< Pair< RealSum, Long > > > imgIntensities = taskExecutor.invokeAll( tasks );
for ( final Future< Pair< RealSum, Long > > future : imgIntensities )
{
s.add( future.get().getA().getSum() );
count += future.get().getB().longValue();
}
}
catch ( final Exception e )
{
IOFunctions.println( "Failed to fuse initial iteration: " + e );
e.printStackTrace();
return -1;
}
taskExecutor.shutdown();
return s.getSum() / (double)count;
}
protected static Img< FloatType > loadInitialImage(
final String fileName,
final boolean checkNumbers,
final float minValue,
final Dimensions dimensions,
final ImgFactory< FloatType > imageFactory )
{
IOFunctions.println( "Loading image '" + fileName + "' as start for iteration." );
final ImagePlus impPSI = LegacyStackImgLoaderIJ.open( new File( fileName ) );
if ( impPSI == null )
{
IOFunctions.println( "Could not load image '" + fileName + "'." );
return null;
}
final long[] dimPsi = impPSI.getStack().getSize() == 1 ?
new long[]{ impPSI.getWidth(), impPSI.getHeight() } : new long[]{ impPSI.getWidth(), impPSI.getHeight(), impPSI.getStack().getSize() };
final Img< FloatType > psi = imageFactory.create( dimPsi, new FloatType() );
LegacyStackImgLoaderIJ.imagePlus2ImgLib2Img( impPSI, psi, false );
if ( psi == null )
{
IOFunctions.println( "Could not load image '" + fileName + "'." );
return null;
}
else
{
boolean dimensionsMatch = true;
final long dim[] = new long[ dimensions.numDimensions() ];
for ( int d = 0; d < psi.numDimensions(); ++d )
{
if ( psi.dimension( d ) != dimensions.dimension( d ) )
dimensionsMatch = false;
dim[ d ] = dimensions.dimension( d );
}
if ( !dimensionsMatch )
{
IOFunctions.println(
"Dimensions of '" + fileName + "' do not match: " +
Util.printCoordinates( dimPsi ) + " != " + Util.printCoordinates( dim ) );
return null;
}
if ( checkNumbers )
{
IOFunctions.println(
"Checking values of '" + fileName + "' you can disable this check by setting " +
"spim.process.fusion.deconvolution.MVDeconvolution.checkNumbers = false;" );
boolean smaller = false;
boolean hasZerosOrNeg = false;
for ( final FloatType v : psi )
{
if ( v.get() < minValue )
smaller = true;
if ( v.get() <= 0 )
{
hasZerosOrNeg = true;
v.set( minValue );
}
}
if ( smaller )
IOFunctions.println(
"Some values '" + fileName + "' are smaller than the minimal value of " +
minValue + ", this can lead to instabilities." );
if ( hasZerosOrNeg )
IOFunctions.println(
"Some values '" + fileName + "' were smaller or equal to zero," +
"they have been replaced with the min value of " + minValue );
}
}
return psi;
}
public MVDeconInput getData() { return views; }
public String getName() { return name; }
public Img< FloatType > getPsi() { return psi; }
public int getCurrentIteration() { return i; }
public void runIteration()
{
runIteration( psi, tmp1, tmp2, data, lambda, minValue, collectStatistics, i++ );
}
final private static void runIteration(
final Img< FloatType > psi,
final Img< FloatType > tmp1, // a temporary image using the same ImgFactory as PSI
final Img< FloatType > tmp2, // a temporary image using the same ImgFactory as PSI
final ArrayList< MVDeconFFT > data,
final double lambda,
final float minValue,
final boolean collectStatistic,
final int iteration )
{
IOFunctions.println( "iteration: " + iteration + " (" + new Date(System.currentTimeMillis()) + ")" );
final int numViews = data.size();
final int nThreads = Threads.numThreads();
final int nPortions = nThreads * 2;
// split up into many parts for multithreading
final Vector< ImagePortion > portions = FusionHelper.divideIntoPortions( psi.size(), nPortions );
final ArrayList< Callable< Void > > tasks = new ArrayList< Callable< Void > >();
for ( int view = 0; view < numViews; ++view )
{
final MVDeconFFT processingData = data.get( view );
//
// convolve psi (current guess of the image) with the PSF of the current view
// [psi >> tmp1]
//
processingData.convolve1( psi, tmp1 );
//new DisplayImage().exportImage( tmp1, "psi" );
//new DisplayImage().exportImage( tmp1, "psi blurred" );
//
// compute quotient img/psiBlurred
// [tmp1, img >> tmp1]
//
tasks.clear();
for ( final ImagePortion portion : portions )
{
tasks.add( new Callable< Void >()
{
@Override
public Void call() throws Exception
{
computeQuotient( portion.getStartPosition(), portion.getLoopSize(), tmp1, processingData.getImage() );
return null;
}
});
}
execTasks( tasks, nThreads, "compute quotient" );
//new DisplayImage().exportImage( processingData.getImage(), "img" );
//new DisplayImage().exportImage( tmp1, "quotient" );
//
// blur the residuals image with the kernel
// (this cannot be don in-place as it might be computed in blocks sequentially,
// and the input for the n+1'th block cannot be formed by the written back output
// of the n'th block)
// [tmp1 >> tmp2]
//
processingData.convolve2( tmp1, tmp2 );
//new DisplayImage().exportImage( tmp2, "quotient blurred" );
//
// compute final values
// [psi, weights, tmp2 >> psi]
//
final double[][] sumMax = new double[ nPortions ][ 2 ];
tasks.clear();
for ( int i = 0; i < portions.size(); ++i )
{
final ImagePortion portion = portions.get( i );
final int portionId = i;
tasks.add( new Callable< Void >()
{
@Override
public Void call() throws Exception
{
computeFinalValues( portion.getStartPosition(), portion.getLoopSize(), psi, tmp2, processingData.getWeight(), lambda, sumMax[ portionId ] );
return null;
}
});
}
execTasks( tasks, nThreads, "compute final values" );
// accumulate the results from the individual threads
double sumChange = 0;
double maxChange = -1;
for ( int i = 0; i < nPortions; ++i )
{
sumChange += sumMax[ i ][ 0 ];
maxChange = Math.max( maxChange, sumMax[ i ][ 1 ] );
}
IOFunctions.println( "iteration: " + iteration + ", view: " + view + " --- sum change: " + sumChange + " --- max change per pixel: " + maxChange );
//new DisplayImage().exportImage( processingData.getWeight(), "weight" );
//new DisplayImage().exportImage( psi, "psi new" );
//SimpleMultiThreading.threadHaltUnClean();
}
//SimpleMultiThreading.threadHaltUnClean();
}
private static final void execTasks( final ArrayList< Callable< Void > > tasks, final int nThreads, final String jobDescription )
{
final ExecutorService taskExecutor = Executors.newFixedThreadPool( nThreads );
try
{
// invokeAll() returns when all tasks are complete
taskExecutor.invokeAll( tasks );
}
catch ( final InterruptedException e )
{
IOFunctions.println( "Failed to " + jobDescription + ": " + e );
e.printStackTrace();
return;
}
taskExecutor.shutdown();
}
/**
* One thread of a method to compute the quotient between two images of the multiview deconvolution
*
* @param start
* @param loopSize
* @param psiBlurred
* @param observedImg
*/
private static final void computeQuotient(
final long start,
final long loopSize,
final RandomAccessibleInterval< FloatType > psiBlurred,
final RandomAccessibleInterval< FloatType > observedImg )
{
final IterableInterval< FloatType > psiBlurredIterable = Views.iterable( psiBlurred );
final IterableInterval< FloatType > observedImgIterable = Views.iterable( observedImg );
if ( psiBlurredIterable.iterationOrder().equals( observedImgIterable.iterationOrder() ) )
{
final Cursor< FloatType > cursorPsiBlurred = psiBlurredIterable.cursor();
final Cursor< FloatType > cursorImg = observedImgIterable.cursor();
cursorPsiBlurred.jumpFwd( start );
cursorImg.jumpFwd( start );
for ( long l = 0; l < loopSize; ++l )
{
cursorPsiBlurred.fwd();
cursorImg.fwd();
final float psiBlurredValue = cursorPsiBlurred.get().get();
final float imgValue = cursorImg.get().get();
if ( imgValue > 0 )
cursorPsiBlurred.get().set( imgValue / psiBlurredValue );
else
cursorPsiBlurred.get().set( 1 ); // no image data, quotient=1
}
}
else
{
final RandomAccess< FloatType > raPsiBlurred = psiBlurred.randomAccess();
final Cursor< FloatType > cursorImg = observedImgIterable.localizingCursor();
cursorImg.jumpFwd( start );
for ( long l = 0; l < loopSize; ++l )
{
cursorImg.fwd();
raPsiBlurred.setPosition( cursorImg );
final float psiBlurredValue = raPsiBlurred.get().get();
final float imgValue = cursorImg.get().get();
if ( imgValue > 0 )
raPsiBlurred.get().set( imgValue / psiBlurredValue );
else
raPsiBlurred.get().set( 1 ); // no image data, quotient=1
}
}
}
/**
* One thread of a method to compute the quotient between two images of the multiview deconvolution
*
* @param start
* @param loopSize
* @param source
* @param target
*/
public static final void copyImg(
final long start,
final long loopSize,
final RandomAccessibleInterval< FloatType > source,
final RandomAccessibleInterval< FloatType > target )
{
final IterableInterval< FloatType > sourceIterable = Views.iterable( source );
final IterableInterval< FloatType > targetIterable = Views.iterable( target );
if ( sourceIterable.iterationOrder().equals( sourceIterable.iterationOrder() ) )
{
final Cursor< FloatType > cursorSource = sourceIterable.cursor();
final Cursor< FloatType > cursorTarget = targetIterable.cursor();
cursorSource.jumpFwd( start );
cursorTarget.jumpFwd( start );
for ( long l = 0; l < loopSize; ++l )
cursorTarget.next().set( cursorSource.next() );
}
else
{
final RandomAccess< FloatType > raSource = source.randomAccess();
final Cursor< FloatType > cursorTarget = targetIterable.localizingCursor();
cursorTarget.jumpFwd( start );
for ( long l = 0; l < loopSize; ++l )
{
cursorTarget.fwd();
raSource.setPosition( cursorTarget );
cursorTarget.get().set( raSource.get() );
}
}
}
/**
* One thread of a method to compute the final values of one iteration of the multiview deconvolution
*
* @param start
* @param loopSize
* @param psi
* @param integral
* @param weight
* @param lambda
*/
private static final void computeFinalValues(
final long start,
final long loopSize,
final RandomAccessibleInterval< FloatType > psi,
final RandomAccessibleInterval< FloatType > integral,
final RandomAccessibleInterval< FloatType > weight,
final double lambda,
final double[] sumMax )
{
double sumChange = 0;
double maxChange = -1;
final IterableInterval< FloatType > psiIterable = Views.iterable( psi );
final IterableInterval< FloatType > integralIterable = Views.iterable( integral );
final IterableInterval< FloatType > weightIterable = Views.iterable( weight );
if (
psiIterable.iterationOrder().equals( integralIterable.iterationOrder() ) &&
psiIterable.iterationOrder().equals( weightIterable.iterationOrder() ) )
{
final Cursor< FloatType > cursorPsi = psiIterable.cursor();
final Cursor< FloatType > cursorIntegral = integralIterable.cursor();
final Cursor< FloatType > cursorWeight = weightIterable.cursor();
cursorPsi.jumpFwd( start );
cursorIntegral.jumpFwd( start );
cursorWeight.jumpFwd( start );
for ( long l = 0; l < loopSize; ++l )
{
cursorPsi.fwd();
cursorIntegral.fwd();
cursorWeight.fwd();
// get the final value
final float lastPsiValue = cursorPsi.get().get();
final float nextPsiValue = computeNextValue( lastPsiValue, cursorIntegral.get().get(), cursorWeight.get().get(), lambda );
// store the new value
cursorPsi.get().set( (float)nextPsiValue );
// statistics
final float change = change( lastPsiValue, nextPsiValue );
sumChange += change;
maxChange = Math.max( maxChange, change );
}
}
else
{
final Cursor< FloatType > cursorPsi = psiIterable.localizingCursor();
final RandomAccess< FloatType > raIntegral = integral.randomAccess();
final RandomAccess< FloatType > raWeight = weight.randomAccess();
cursorPsi.jumpFwd( start );
for ( long l = 0; l < loopSize; ++l )
{
cursorPsi.fwd();
raIntegral.setPosition( cursorPsi );
raWeight.setPosition( cursorPsi );
// get the final value
final float lastPsiValue = cursorPsi.get().get();
float nextPsiValue = computeNextValue( lastPsiValue, raIntegral.get().get(), raWeight.get().get(), lambda );
// store the new value
cursorPsi.get().set( (float)nextPsiValue );
// statistics
final float change = change( lastPsiValue, nextPsiValue );
sumChange += change;
maxChange = Math.max( maxChange, change );
}
}
sumMax[ 0 ] = sumChange;
sumMax[ 1 ] = maxChange;
}
private static final float change( final float lastPsiValue, final float nextPsiValue ) { return Math.abs( ( nextPsiValue - lastPsiValue ) ); }
/**
* compute the next value for a specific pixel
*
* @param lastPsiValue - the previous value
* @param integralValue - result from the integral
* @param lambda - if > 0, regularization
* @return
*/
private static final float computeNextValue( final float lastPsiValue, final float integralValue, final float weight, final double lambda )
{
final float value = lastPsiValue * integralValue;
final float adjustedValue;
if ( value > 0 )
{
//
// perform Tikhonov regularization if desired
//
if ( lambda > 0 )
adjustedValue = (float)tikhonov( value, lambda );
else
adjustedValue = value;
}
else
{
adjustedValue = minValue;
}
//
// get the final value and some statistics
//
final float nextPsiValue;
if ( Double.isNaN( adjustedValue ) )
nextPsiValue = (float)minValue;
else
nextPsiValue = (float)Math.max( minValue, adjustedValue );
// compute the difference between old and new and apply the appropriate amount
return lastPsiValue + ( ( nextPsiValue - lastPsiValue ) * weight );
}
private static final double tikhonov( final double value, final double lambda ) { return ( Math.sqrt( 1.0 + 2.0*lambda*value ) - 1.0 ) / lambda; }
public static void main( String[] args )
{
for ( double d = 0; d < 10; d = d + 0.1 )
{
System.out.println( d*10000 + ": " + tikhonov( d*10000, 0.0006 ) );
System.out.println( d + ": " + tikhonov( d, 0.0006 ) );
}
}
}