/*-
* #%L
* Fiji distribution of ImageJ for the life sciences.
* %%
* Copyright (C) 2007 - 2017 Fiji developers.
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 2 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program. If not, see
* <http://www.gnu.org/licenses/gpl-2.0.html>.
* #L%
*/
package mpicbg.spim.postprocessing.deconvolution2;
import ij.CompositeImage;
import ij.IJ;
import ij.ImagePlus;
import ij.ImageStack;
import java.util.ArrayList;
import java.util.Date;
import java.util.Vector;
import java.util.concurrent.atomic.AtomicInteger;
import spim.Threads;
import mpicbg.imglib.cursor.Cursor;
import mpicbg.imglib.image.Image;
import mpicbg.imglib.image.ImageFactory;
import mpicbg.imglib.image.display.imagej.ImageJFunctions;
import mpicbg.imglib.io.LOCI;
import mpicbg.imglib.multithreading.Chunk;
import mpicbg.imglib.multithreading.SimpleMultiThreading;
import mpicbg.imglib.type.numeric.real.FloatType;
import mpicbg.spim.io.IOFunctions;
import mpicbg.spim.postprocessing.deconvolution2.LRFFT.PSFTYPE;
import net.imglib2.util.Util;
public class BayesMVDeconvolution implements Deconvolver
{
// if you want to start from a certain iteration
public static String initialImage = null;
// check in advance if values are reasonable
public static boolean checkNumbers = true;
public static boolean debug = true;
public static int debugInterval = 1;
final static float minValue = 0.0001f;
final int numViews, numDimensions;
final float avg;
final double lambda;
ImageStack stack;
CompositeImage ci;
boolean collectStatistics = true;
// current iteration
int i = 0;
// the multi-view deconvolved image
Image<FloatType> psi;
// the input data
final LRInput views;
ArrayList<LRFFT> data;
String name;
public BayesMVDeconvolution( final LRInput views, final PSFTYPE iterationType, final int numIterations, final double lambda, double osemspeedup, final int osemspeedupindex, final String name )
{
this.name = name;
this.data = views.getViews();
this.views = views;
this.numViews = data.size();
this.numDimensions = data.get( 0 ).getImage().getNumDimensions();
this.lambda = lambda;
if ( initialImage != null )
this.psi = loadInitialImage( initialImage, checkNumbers, minValue, data.get( 0 ).getImage().getDimensions(), data.get( 0 ).getImage().getImageFactory() );
final double[] result = AdjustInput.normAllImages( data );
this.avg = (float)result[ 0 ];
if ( osemspeedupindex == 1 )//min
osemspeedup = Math.max( 1, result[ 1 ] );//but not smaller than 1
else if ( osemspeedupindex == 2 )//avg
osemspeedup = Math.max( 1, result[ 2 ] );//but not smaller than 1
adjustOSEMspeedup( views, osemspeedup );
IJ.log( "Average intensity in overlapping area: " + avg );
IJ.log( "OSEM acceleration: " + osemspeedup );
// init all views
views.init( iterationType );
//
// the real data image psi is initialized with the average
// if there was no initial guess loaded
//
if ( this.psi == null )
{
this.psi = data.get( 0 ).getImage().createNewImage( "psi (deconvolved image)" );
for ( final FloatType f : psi )
f.set( avg );
}
IOFunctions.println( "Deconvolved image container: " + psi.getImageFactory().getContainerFactory().getClass().getSimpleName() );
//this.stack = new ImageStack( this.psi.getDimension( 0 ), this.psi.getDimension( 1 ) );
// run the deconvolution
while ( i < numIterations )
{
runIteration();
if ( debug && (i-1) % debugInterval == 0 )
{
psi.getDisplay().setMinMax( 0, 1 );
final ImagePlus tmp = ImageJFunctions.copyToImagePlus( psi );
if ( this.stack == null )
{
this.stack = tmp.getImageStack();
for ( int i = 0; i < this.psi.getDimension( 2 ); ++i )
this.stack.setSliceLabel( "Iteration 1", i + 1 );
tmp.setTitle( "debug view" );
this.ci = new CompositeImage( tmp, CompositeImage.COMPOSITE );
this.ci.setDimensions( 1, this.psi.getDimension( 2 ), 1 );
this.ci.show();
}
else if ( stack.getSize() == this.psi.getDimension( 2 ) )
{
IJ.log( "Stack size = " + this.stack.getSize() );
final ImageStack t = tmp.getImageStack();
for ( int i = 0; i < this.psi.getDimension( 2 ); ++i )
this.stack.addSlice( "Iteration 2", t.getProcessor( i + 1 ) );
IJ.log( "Stack size = " + this.stack.getSize() );
this.ci.hide();
IJ.log( "Stack size = " + this.stack.getSize() );
this.ci = new CompositeImage( new ImagePlus( "debug view", this.stack ), CompositeImage.COMPOSITE );
this.ci.setDimensions( 1, this.psi.getDimension( 2 ), 2 );
this.ci.show();
}
else
{
final ImageStack t = tmp.getImageStack();
for ( int i = 0; i < this.psi.getDimension( 2 ); ++i )
this.stack.addSlice( "Iteration " + i, t.getProcessor( i + 1 ) );
this.ci.setStack( this.stack, 1, this.psi.getDimension( 2 ), stack.getSize() / this.psi.getDimension( 2 ) );
}
/*
Image<FloatType> psiCopy = psi.clone();
//ViewDataBeads.normalizeImage( psiCopy );
psiCopy.setName( "Iteration " + i + " l=" + lambda );
psiCopy.getDisplay().setMinMax( 0, 1 );
ImageJFunctions.copyToImagePlus( psiCopy ).show();
psiCopy.close();
psiCopy = null;*/
}
}
IJ.log( "DONE (" + new Date(System.currentTimeMillis()) + ")." );
}
private void adjustOSEMspeedup( final LRInput views, final double osemspeedup )
{
if ( osemspeedup == 1.0 )
return;
for ( final LRFFT view : views.getViews() )
{
for ( final FloatType f : view.getWeight() )
f.set( Math.min( 1, f.get() * (float)osemspeedup ) ); // individual contribution never higher than 1
}
}
protected static Image< FloatType > loadInitialImage( final String fileName, final boolean checkNumbers, final float minValue, final int[] dimensions, final ImageFactory< FloatType > imageFactory )
{
IOFunctions.println( "Loading image '" + fileName + "' as start for iteration." );
Image< FloatType > psi = LOCI.openLOCIFloatType( fileName, imageFactory );
if ( psi == null )
{
IOFunctions.println( "Could not load image '" + fileName + "'." );
return null;
}
else
{
boolean dimensionsMatch = true;
for ( int d = 0; d < psi.getNumDimensions(); ++d )
if ( psi.getDimension( d ) != dimensions[ d ] )
dimensionsMatch = false;
if ( !dimensionsMatch )
{
IOFunctions.println( "Dimensions of '" + fileName + "' do not match: " + Util.printCoordinates( psi.getDimensions() ) + " != " + Util.printCoordinates( dimensions ) );
psi.close();
return null;
}
if ( checkNumbers )
{
IOFunctions.println( "Checking values of '" + fileName + "' you can disable this check by setting mpicbg.spim.postprocessing.deconvolution2.BayesMVDeconvolution.checkNumbers = false;" );
boolean smaller = false;
boolean hasZerosOrNeg = false;
for ( final FloatType v : psi )
{
if ( v.get() < minValue )
smaller = true;
if ( v.get() <= 0 )
{
hasZerosOrNeg = true;
v.set( minValue );
}
}
if ( smaller )
IOFunctions.println( "Some values '" + fileName + "' are smaller than the minimal value of " + minValue + ", this can lead to instabilities." );
if ( hasZerosOrNeg )
IOFunctions.println( "Some values '" + fileName + "' were smaller or equal to zero, they have been replaced with the min value of " + minValue );
}
}
return psi;
}
public LRInput getData() { return views; }
public String getName() { return name; }
public double getAvg() { return avg; }
public Image<FloatType> getPsi() { return psi; }
public int getCurrentIteration() { return i; }
public void runIteration()
{
runIteration( psi, data, lambda, minValue, collectStatistics, i++ );
}
final private static void runIteration( final Image< FloatType> psi, final ArrayList< LRFFT > data,
final double lambda, final float minValue, final boolean collectStatistic, final int iteration )
{
IJ.log( "iteration: " + iteration + " (" + new Date(System.currentTimeMillis()) + ")" );
final int numViews = data.size();
final Vector< Chunk > threadChunks = SimpleMultiThreading.divideIntoChunks( psi.getNumPixels(), Threads.numThreads() );
final int numThreads = threadChunks.size();
final Image< FloatType > lastIteration;
if ( collectStatistic )
lastIteration = psi.clone();
else
lastIteration = null;
//int view = iteration % numViews;
for ( int view = 0; view < numViews; ++view )
{
final LRFFT processingData = data.get( view );
long time = System.currentTimeMillis();
// convolve psi (current guess of the image) with the PSF of the current view
final Image<FloatType> psiBlurred = processingData.convolve1( psi );
//System.out.println( view + " 1: " + fftConvolution.getProcessingTime() + " ms." );
System.out.println( view + " a: " + (time - System.currentTimeMillis()) + " ms." );
// size = 666, 363, 537
// compute quotient img/psiBlurred
final AtomicInteger ai = new AtomicInteger(0);
final Thread[] threads = SimpleMultiThreading.newThreads( numThreads );
for ( int ithread = 0; ithread < threads.length; ++ithread )
threads[ithread] = new Thread(new Runnable()
{
public void run()
{
// Thread ID
final int myNumber = ai.getAndIncrement();
// get chunk of pixels to process
final Chunk myChunk = threadChunks.get( myNumber );
computeQuotient( myChunk.getStartPosition(), myChunk.getLoopSize(), psiBlurred, processingData );
}
});
SimpleMultiThreading.startAndJoin( threads );
//System.out.println( view + " b: " + (time - System.currentTimeMillis()) + " ms." );
time = System.currentTimeMillis();
// blur the residuals image with the kernel
final Image< FloatType > integral = processingData.convolve2( psiBlurred );
//System.out.println( view + " 2: " + invFFConvolution.getProcessingTime() + " ms." );
System.out.println( view + " b: " + (time - System.currentTimeMillis()) + " ms." );
ai.set( 0 );
for ( int ithread = 0; ithread < threads.length; ++ithread )
threads[ithread] = new Thread(new Runnable()
{
public void run()
{
// Thread ID
final int myNumber = ai.getAndIncrement();
// get chunk of pixels to process
final Chunk myChunk = threadChunks.get( myNumber );
computeFinalValues( myChunk.getStartPosition(), myChunk.getLoopSize(), psi, integral, processingData.getWeight(), lambda );
}
});
SimpleMultiThreading.startAndJoin( threads );
// the result from the previous iteration
//System.out.println( view + " d: " + (time - System.currentTimeMillis()) + " ms." );
}
if ( collectStatistic )
{
final AtomicInteger ai = new AtomicInteger(0);
final Thread[] threads = SimpleMultiThreading.newThreads( numThreads );
final double[][] sumMax = new double[ numThreads ][ 2 ];
for ( int ithread = 0; ithread < threads.length; ++ithread )
threads[ithread] = new Thread(new Runnable()
{
public void run()
{
// Thread ID
final int myNumber = ai.getAndIncrement();
// get chunk of pixels to process
final Chunk myChunk = threadChunks.get( myNumber );
collectStatistics( myChunk.getStartPosition(), myChunk.getLoopSize(), psi, lastIteration, sumMax[ myNumber ] );
}
});
SimpleMultiThreading.startAndJoin( threads );
// accumulate the results from the individual threads
double sumChange = 0;
double maxChange = -1;
for ( int i = 0; i < numThreads; ++i )
{
sumChange += sumMax[ i ][ 0 ];
maxChange = Math.max( maxChange, sumMax[ i ][ 1 ] );
}
IJ.log("iteration: " + iteration + " --- sum change: " + sumChange + " --- max change per pixel: " + maxChange );
}
//System.out.println( "final: " + (time - System.currentTimeMillis()) + " ms." );
}
private static final void collectStatistics( final long start, final long loopSize, final Image< FloatType > psi, final Image< FloatType > lastIteration, final double[] sumMax )
{
double sumChange = 0;
double maxChange = -1;
final Cursor< FloatType > cursorPsi = psi.createCursor();
final Cursor< FloatType > cursorLast = lastIteration.createCursor();
cursorPsi.fwd( start );
cursorLast.fwd( start );
for ( long l = 0; l < loopSize; ++l )
{
final float last = cursorLast.next().get();
final float next = cursorPsi.next().get();
final float change = Math.abs( next - last );
sumChange += change;
maxChange = Math.max( maxChange, change );
}
sumMax[ 0 ] = sumChange;
sumMax[ 1 ] = maxChange;
}
private static final void computeQuotient( final long start, final long loopSize, final Image< FloatType > psiBlurred, final LRFFT processingData )
{
final Cursor<FloatType> cursorImg = processingData.getImage().createCursor();
final Cursor<FloatType> cursorPsiBlurred = psiBlurred.createCursor();
cursorImg.fwd( start );
cursorPsiBlurred.fwd( start );
for ( long l = 0; l < loopSize; ++l )
{
cursorImg.fwd();
cursorPsiBlurred.fwd();
final float imgValue = cursorImg.getType().get();
final float psiBlurredValue = cursorPsiBlurred.getType().get();
cursorPsiBlurred.getType().set( imgValue / psiBlurredValue );
}
cursorImg.close();
cursorPsiBlurred.close();
}
private static final void computeFinalValues( final long start, final long loopSize, final Image< FloatType > psi, final Image<FloatType> integral, final Image<FloatType> weight, final double lambda )
{
final Cursor< FloatType > cursorPsi = psi.createCursor();
final Cursor< FloatType > cursorIntegral = integral.createCursor();
final Cursor< FloatType > cursorWeight = weight.createCursor();
cursorPsi.fwd( start );
cursorIntegral.fwd( start );
cursorWeight.fwd( start );
for ( long l = 0; l < loopSize; ++l )
{
cursorPsi.fwd();
cursorIntegral.fwd();
cursorWeight.fwd();
final float lastPsiValue = cursorPsi.getType().get();
float value = lastPsiValue * cursorIntegral.getType().get();
if ( value > 0 )
{
//
// perform Tikhonov regularization if desired
//
if ( lambda > 0 )
value = ( (float)( (Math.sqrt( 1.0 + 2.0*lambda*value ) - 1.0) / lambda ) );
}
else
{
value = minValue;
}
//
// get the final value and some statistics
//
float nextPsiValue;
if ( Double.isNaN( value ) )
nextPsiValue = (float)minValue;
else
nextPsiValue = (float)Math.max( minValue, value );
// compute the difference between old and new
float change = nextPsiValue - lastPsiValue;
// apply the apropriate amount
change *= cursorWeight.getType().get();
nextPsiValue = lastPsiValue + change;
// store the new value
cursorPsi.getType().set( (float)nextPsiValue );
}
}
}