/*-
* #%L
* Fiji distribution of ImageJ for the life sciences.
* %%
* Copyright (C) 2007 - 2017 Fiji developers.
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 2 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program. If not, see
* <http://www.gnu.org/licenses/gpl-2.0.html>.
* #L%
*/
package mpicbg.spim.postprocessing.deconvolution2;
import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import mpicbg.imglib.algorithm.fft.FourierConvolution;
import mpicbg.imglib.algorithm.mirror.MirrorImage;
import mpicbg.imglib.container.ContainerFactory;
import mpicbg.imglib.container.array.Array;
import mpicbg.imglib.container.array.ArrayContainerFactory;
import mpicbg.imglib.container.basictypecontainer.FloatAccess;
import mpicbg.imglib.container.basictypecontainer.array.FloatArray;
import mpicbg.imglib.container.constant.ConstantContainer;
import mpicbg.imglib.container.imageplus.ImagePlusContainer;
import mpicbg.imglib.container.imageplus.ImagePlusContainerFactory;
import mpicbg.imglib.cursor.Cursor;
import mpicbg.imglib.image.Image;
import mpicbg.imglib.image.ImageFactory;
import mpicbg.imglib.image.display.imagej.ImageJFunctions;
import mpicbg.imglib.multithreading.SimpleMultiThreading;
import mpicbg.imglib.outofbounds.OutOfBoundsStrategyValueFactory;
import mpicbg.imglib.type.numeric.real.FloatType;
import mpicbg.imglib.wrapper.ImgLib1;
import mpicbg.imglib.wrapper.ImgLib2;
import mpicbg.spim.io.IOFunctions;
import net.imglib2.exception.ImgLibException;
import net.imglib2.img.Img;
import net.imglib2.img.imageplus.ImagePlusImg;
import spim.process.cuda.CUDAFourierConvolution;
/**
 * Per-view state of the multi-view deconvolution: the view's image, its weights and its PSF
 * ({@code kernel1}), plus the derived kernel {@code kernel2} computed in
 * {@link #init(PSFTYPE, ArrayList)}. Provides convolution of an image with either kernel on the
 * CPU, on CUDA devices, or on both, optionally split into blocks.
 *
 * Device ids in {@code deviceList}: {@code -1} selects the CPU, values {@code >= 0} select CUDA
 * devices.
 */
public class LRFFT
{
	/**
	 * Strategy used by {@link #init(PSFTYPE, ArrayList)} to derive {@code kernel2} (the kernel used
	 * by {@link #convolve2(Image)}) from {@code kernel1}.
	 */
	public static enum PSFTYPE { OPTIMIZATION_II, OPTIMIZATION_I, EFFICIENT_BAYESIAN, INDEPENDENT };

	/**
	 * Shared CUDA FFT convolution implementation. Not read by the live code of this class.
	 * NOTE(review): presumably set by the caller and used by LRFFTThreads — confirm.
	 */
	public static CUDAFourierConvolution cuda = null;

	/** The view's image, its weights, its PSF (kernel1) and the kernel used by convolve2 (kernel2). */
	private Image<FloatType> image, weight, kernel1, kernel2;

	/** Only copied by {@link #clone()} here; NOTE(review): presumably written by the deconvolution loop. */
	Image<FloatType> viewContribution = null;

	/** CPU convolvers for kernel1/kernel2, created in init; null when no CPU device (-1) is listed. */
	FourierConvolution<FloatType, FloatType> fftConvolution1, fftConvolution2;

	/** Number of views in the acquisition; 0 means "not set" and is replaced by 1 in init. */
	protected int numViews = 0;

	PSFTYPE iterationType;
	ArrayList< LRFFT > views;

	/** useBlocks: convolve block-wise; useCUDA: some device id >= 0; useCPU: some device id == -1. */
	final boolean useBlocks, useCUDA, useCPU;

	/** blockSize: size of one block (null without blocks); deviceList: device ids as passed in. */
	final int[] blockSize, deviceList;

	/** First entry of deviceList and the number of entries. */
	final int device0, numDevices;

	/** The blocks the image is divided into; null without blocks. */
	final Block[] blocks;

	/** Factory for block-sized array images; null without blocks. */
	final ImageFactory< FloatType > factory;

	/**
	 * Used to determine if the convolutions have already been computed for the current iteration
	 */
	int iteration = -1;

	/**
	 * Creates a view from ImgLib2 images by wrapping them as ImgLib1 images (see
	 * {@link #wrap(Img)}).
	 *
	 * @param image - the view's image
	 * @param weight - the per-pixel weights of the view
	 * @param kernel - the PSF of the view
	 * @param deviceList - device ids, -1 = CPU, >= 0 = CUDA device
	 * @param useBlocks - whether to convolve block-wise
	 * @param blockSize - the block size used when useBlocks is true
	 */
	public LRFFT(
			final Img< net.imglib2.type.numeric.real.FloatType > image,
			final Img< net.imglib2.type.numeric.real.FloatType > weight,
			final Img< net.imglib2.type.numeric.real.FloatType > kernel,
			final int[] deviceList, final boolean useBlocks, final int[] blockSize )
	{
		this( wrap( image ), wrap( weight ), wrap( kernel ), deviceList, useBlocks, blockSize );
	}

	/**
	 * Wraps an ImgLib2 float image as an ImgLib1 image. ImagePlusImg instances are re-wrapped via
	 * their ImagePlus, everything else via {@code ImgLib2.wrapFloatToImgLib1}.
	 *
	 * @param i - the ImgLib2 image
	 * @return the ImgLib1 view of the image, or null if the ImagePlus could not be obtained
	 */
	@SuppressWarnings("rawtypes")
	public static final Image< FloatType > wrap( final Img< net.imglib2.type.numeric.real.FloatType > i )
	{
		if ( i instanceof ImagePlusImg )
		{
			try
			{
				return ImageJFunctions.wrapFloat( ((ImagePlusImg) i).getImagePlus() );
			}
			catch ( ImgLibException e )
			{
				// NOTE(review): the failure is only reported on stderr; callers receive null
				e.printStackTrace();
				return null;
			}
		}

		return ImgLib2.wrapFloatToImgLib1( i );
	}

	/**
	 * Wraps an ImgLib1 float image as an ImgLib2 image. Images backed by an ImagePlusContainer are
	 * re-wrapped via their ImagePlus, everything else via {@code ImgLib1.wrapFloatToImgLib2}.
	 *
	 * @param i - the ImgLib1 image
	 * @return the ImgLib2 view of the image, or null if the ImagePlus could not be obtained
	 */
	@SuppressWarnings("rawtypes")
	public static final Img< net.imglib2.type.numeric.real.FloatType > wrap( final Image< FloatType > i )
	{
		final ContainerFactory c = i.getContainerFactory();

		if ( c instanceof ImagePlusContainerFactory )
		{
			try
			{
				return net.imglib2.img.display.imagej.ImageJFunctions.wrapFloat( ((ImagePlusContainer)i.getContainer()).getImagePlus() );
			}
			catch ( mpicbg.imglib.exception.ImgLibException e )
			{
				// NOTE(review): the failure is only reported on stderr; callers receive null
				e.printStackTrace();
				return null;
			}
		}

		return ImgLib1.wrapFloatToImgLib2( i );
	}

	/**
	 * Creates a view and decides how convolutions are executed.
	 *
	 * Blocks are used if requested, or whenever a CUDA device is listed. In the latter case without
	 * requested blocks, one block of size image + kernel - 1 per dimension is used.
	 *
	 * @param image - the view's image
	 * @param weight - the per-pixel weights of the view
	 * @param kernel - the PSF of the view
	 * @param deviceList - device ids, -1 = CPU, >= 0 = CUDA device; must contain at least one of them
	 * @param useBlocks - whether to convolve block-wise
	 * @param blockSize - the block size used when useBlocks is true
	 * @throws IllegalArgumentException if deviceList is null/empty or lists neither CPU nor CUDA device
	 */
	public LRFFT(
			final Image<FloatType> image,
			final Image<FloatType> weight,
			final Image<FloatType> kernel,
			final int[] deviceList, final boolean useBlocks, final int[] blockSize )
	{
		if ( deviceList == null || deviceList.length == 0 )
			throw new IllegalArgumentException( "deviceList must contain at least one device (-1 = CPU, >= 0 = CUDA device id)" );

		// figure out if we need GPU and/or CPU
		boolean anyGPU = false;
		boolean anyCPU = false;

		for ( final int i : deviceList )
		{
			if ( i >= 0 )
				anyGPU = true;
			else if ( i == -1 )
				anyCPU = true;
		}

		// without any usable device, convolve1/convolve2 would later fail on null blocks/convolvers
		if ( !anyGPU && !anyCPU )
			throw new IllegalArgumentException( "deviceList contains neither the CPU (-1) nor a CUDA device (>= 0)" );

		this.image = image;
		this.kernel1 = kernel;
		this.weight = weight;

		this.deviceList = deviceList;
		this.device0 = deviceList[ 0 ];
		this.numDevices = deviceList.length;

		this.useCUDA = anyGPU;
		this.useCPU = anyCPU;

		if ( useBlocks )
		{
			this.useBlocks = true;

			// copy the user-defined block size
			this.blockSize = new int[ image.getNumDimensions() ];

			for ( int d = 0; d < this.blockSize.length; ++d )
				this.blockSize[ d ] = blockSize[ d ];

			// NOTE(review): per the original author, the block size might change during division if it was too small
			this.blocks = Block.divideIntoBlocks( image.getDimensions(), this.blockSize, kernel.getDimensions() );

			IOFunctions.println( "Number of blocks: " + this.blocks.length );

			this.factory = new ImageFactory< FloatType >( new FloatType(), new ArrayContainerFactory() );
		}
		else if ( this.useCUDA ) // and no blocks, i.e. one big block
		{
			this.useBlocks = true;

			// define the blocksize so that it is one single block
			this.blockSize = new int[ image.getNumDimensions() ];

			for ( int d = 0; d < this.blockSize.length; ++d )
				this.blockSize[ d ] = image.getDimension( d ) + kernel.getDimension( d ) - 1;

			this.blocks = Block.divideIntoBlocks( image.getDimensions(), this.blockSize, kernel.getDimensions() );
			this.factory = new ImageFactory< FloatType >( new FloatType(), new ArrayContainerFactory() );
		}
		else
		{
			this.blocks = null;
			this.blockSize = null;
			this.factory = null;
			this.useBlocks = false;
		}
	}

	/**
	 * Creates a view whose weights are a constant 1 everywhere.
	 *
	 * @param image - the view's image
	 * @param kernel - the PSF of the view
	 * @param deviceList - device ids, -1 = CPU, >= 0 = CUDA device
	 * @param useBlocks - whether to convolve block-wise
	 * @param blockSize - the block size used when useBlocks is true
	 */
	public LRFFT( final Image<FloatType> image, final Image<FloatType> kernel, final int[] deviceList, final boolean useBlocks, final int[] blockSize )
	{
		this( image, new Image< FloatType > ( new ConstantContainer< FloatType >( image.getDimensions(), new FloatType( 1 ) ), new FloatType() ), kernel, deviceList, useBlocks, blockSize );
	}

	/**
	 * @param numViews - the number of views in the acquisition, determines the exponential of the kernel
	 */
	protected void setNumViews( final int numViews ) { this.numViews = numViews; }

	/**
	 * This method is called once all views are added to the {@link LRInput}.
	 *
	 * Normalizes kernel1 in place, derives kernel2 according to iterationType, and creates the CPU
	 * convolvers if the CPU is among the devices.
	 *
	 * @param iterationType - how kernel2 is derived from kernel1
	 * @param views - all views of the acquisition (including this one)
	 */
	protected void init( final PSFTYPE iterationType, final ArrayList< LRFFT > views )
	{
		// normalize kernel so that sum of all pixels == 1
		AdjustInput.normImage( kernel1 );

		this.iterationType = iterationType;
		this.views = views;

		if ( numViews == 0 )
		{
			System.out.println( "Warning, numViews was not set." );
			numViews = 1;
		}

		if ( numViews == 1 || iterationType == PSFTYPE.INDEPENDENT )
		{
			// compute the inverted kernel (switch dimensions)
			this.kernel2 = computeInvertedKernel( this.kernel1 );
		}
		else if ( iterationType == PSFTYPE.EFFICIENT_BAYESIAN )
		{
			this.kernel2 = computeEfficientBayesianKernel( views );
		}
		else if ( iterationType == PSFTYPE.OPTIMIZATION_I )
		{
			this.kernel2 = computeOptimizationIKernel( views );
		}
		else //if ( iterationType == PSFTYPE.OPTIMIZATION_II )
		{
			// compute the exponential kernel, norm it and invert it
			final Image< FloatType > exponentialKernel = computeExponentialKernel( this.kernel1, numViews );
			AdjustInput.normImage( exponentialKernel );
			this.kernel2 = computeInvertedKernel( exponentialKernel );
		}

		createCPUConvolutions();
	}

	/**
	 * Computes the compound kernel of the efficient bayesian multi-view deconvolution for this view
	 * \phi_v(x_v), normalized to sum 1:
	 *
	 * P_v^compound = P_v^{*} prod{w \in W_v} P_v^{*} \ast P_w \ast P_w^{*}
	 *
	 * where {*} refers to the inverted coordinates and W_v are all other views.
	 */
	private Image< FloatType > computeEfficientBayesianKernel( final ArrayList< LRFFT > views )
	{
		// P_v^{*}; computeInvertedKernel works on a copy, so kernel1 stays untouched
		final Image< FloatType > compound = computeInvertedKernel( this.kernel1 );

		for ( final LRFFT view : views )
		{
			if ( view == this )
				continue;

			// convolve first P_v^{*} with P_w
			final FourierConvolution<FloatType, FloatType> conv1 = createKernelConvolution( computeInvertedKernel( this.kernel1 ), view.kernel1 );
			conv1.process();

			// and now convolve the result with P_w^{*}
			final FourierConvolution<FloatType, FloatType> conv2 = createKernelConvolution( conv1.getResult(), computeInvertedKernel( view.kernel1 ) );
			conv2.process();

			// multiply the result with P_v^{*} yielding the compound kernel
			multiplyInPlace( compound, conv2.getResult() );
		}

		AdjustInput.normImage( compound );
		return compound;
	}

	/**
	 * Computes the simplified compound kernel used for {@link PSFTYPE#OPTIMIZATION_I}:
	 *
	 * P_v^compound = P_v^{*} prod{w \in W_v} P_v^{*} \ast P_w
	 *
	 * The code builds P_v * prod (P_v \ast P_w^{*}), normalizes it to sum 1 and returns its
	 * inverted (mirrored) version.
	 */
	private Image< FloatType > computeOptimizationIKernel( final ArrayList< LRFFT > views )
	{
		final Image< FloatType > compound = this.kernel1.clone();

		for ( final LRFFT view : views )
		{
			if ( view == this )
				continue;

			final FourierConvolution<FloatType, FloatType> conv = createKernelConvolution( this.kernel1, computeInvertedKernel( view.kernel1 ) );
			conv.process();

			// multiply with the kernel
			multiplyInPlace( compound, conv.getResult() );
		}

		AdjustInput.normImage( compound );
		return computeInvertedKernel( compound );
	}

	/**
	 * Creates a multi-threaded kernel-kernel convolution that drops the image FFT after use and
	 * treats pixels outside the image as a constant value.
	 * NOTE(review): the value used by OutOfBoundsStrategyValueFactory's no-arg constructor is not visible here.
	 */
	private static FourierConvolution< FloatType, FloatType > createKernelConvolution( final Image< FloatType > img, final Image< FloatType > kernel )
	{
		final FourierConvolution< FloatType, FloatType > conv = new FourierConvolution< FloatType, FloatType >( img, kernel );
		conv.setNumThreads();
		conv.setKeepImgFFT( false );
		conv.setImageOutOfBoundsStrategy( new OutOfBoundsStrategyValueFactory< FloatType >() );
		return conv;
	}

	/**
	 * Multiplies target pixel-wise by factor, pairing pixels in iteration order.
	 * Assumes factor has no more pixels than target (same layout as the original loop).
	 */
	private static void multiplyInPlace( final Image< FloatType > target, final Image< FloatType > factor )
	{
		final Cursor< FloatType > cursor = target.createCursor();

		for ( final FloatType t : factor )
		{
			cursor.fwd();
			cursor.getType().set( t.get() * cursor.getType().get() );
		}
	}

	/**
	 * Creates fftConvolution1/fftConvolution2 (kernel1/kernel2) if the CPU is among the devices,
	 * otherwise sets both to null. With blocks, they are set up on an empty block-sized image;
	 * without blocks, on this view's image. The images actually convolved are supplied later
	 * (convolve1/convolve2 call replaceImage in the non-block path).
	 */
	private void createCPUConvolutions()
	{
		if ( !useCPU )
		{
			this.fftConvolution1 = null;
			this.fftConvolution2 = null;
			return;
		}

		final Image< FloatType > template = useBlocks ? factory.createImage( blockSize ) : this.image;

		this.fftConvolution1 = createImageConvolution( template, this.kernel1 );
		this.fftConvolution2 = createImageConvolution( template, this.kernel2 );
	}

	/** Creates a multi-threaded image-kernel convolution that drops the image FFT after use. */
	private static FourierConvolution< FloatType, FloatType > createImageConvolution( final Image< FloatType > img, final Image< FloatType > kernel )
	{
		final FourierConvolution< FloatType, FloatType > conv = new FourierConvolution< FloatType, FloatType >( img, kernel );
		conv.setNumThreads();
		conv.setKeepImgFFT( false );
		return conv;
	}

	/**
	 * Returns a copy of the kernel with every pixel raised to the power numViews.
	 *
	 * @param kernel - the kernel (not modified)
	 * @param numViews - the exponent; 0 (or less) yields an image of ones
	 * @return the new exponential kernel (not normalized)
	 */
	public static Image<FloatType> computeExponentialKernel( final Image<FloatType> kernel, final int numViews )
	{
		final Image<FloatType> exponentialKernel = kernel.clone();

		for ( final FloatType f : exponentialKernel )
			f.set( pow( f.get(), numViews ) );

		return exponentialKernel;
	}

	/**
	 * Returns a copy of the kernel mirrored along every dimension.
	 *
	 * @param kernel - the kernel (not modified)
	 * @return the new inverted kernel
	 */
	public static Image< FloatType > computeInvertedKernel( final Image< FloatType > kernel )
	{
		final Image< FloatType > invKernel = kernel.clone();

		for ( int d = 0; d < invKernel.getNumDimensions(); ++d )
			new MirrorImage< FloatType >( invKernel, d ).process();

		return invKernel;
	}

	/**
	 * Computes value^power by repeated multiplication. Starting from 1 makes power 0 return 1
	 * (the former implementation returned value); negative powers also return 1.
	 */
	private static float pow( final float value, final int power )
	{
		float result = 1f;

		for ( int i = 0; i < power; ++i )
			result *= value;

		return result;
	}

	/** Replaces the image and marks the convolutions of the current iteration as not computed. */
	public void setImage( final Image<FloatType> image )
	{
		this.image = image;
		setCurrentIteration( -1 );
	}

	public void setWeight( final Image<FloatType> weight ) { this.weight = weight; }

	/**
	 * Replaces kernel1, re-runs {@link #init(PSFTYPE, ArrayList)} with the previous iteration type
	 * and views, and marks the convolutions of the current iteration as not computed.
	 */
	public void setKernel( final Image<FloatType> kernel )
	{
		this.kernel1 = kernel;

		init( iterationType, views );

		setCurrentIteration( -1 );
	}

	public Image<FloatType> getImage() { return image; }
	public Image<FloatType> getWeight() { return weight; }
	public Image<FloatType> getKernel1() { return kernel1; }
	public Image<FloatType> getKernel2() { return kernel2; }

	public void setCurrentIteration( final int i ) { this.iteration = i; }
	public int getCurrentIteration() { return iteration; }

	/**
	 * convolves the image with kernel1
	 *
	 * @param image - the image to convolve with
	 * @return the convolved image; without blocks on the CPU this is the convolver's result image,
	 *         otherwise a new image of the same size as the input
	 */
	public Image< FloatType > convolve1( final Image< FloatType > image )
	{
		if ( useCPU && !useCUDA )
		{
			if ( useBlocks )
			{
				final Image< FloatType > result = image.createNewImage();
				final Image< FloatType > block = factory.createImage( blockSize );

				try
				{
					for ( int i = 0; i < blocks.length; ++i )
						LRFFTThreads.convolve1BlockCPU( blocks[ i ], i, image, result, block, fftConvolution1 );
				}
				finally
				{
					block.close();
				}

				return result;
			}
			else
			{
				// one big image on the CPU: reuse the convolver created in init
				final FourierConvolution<FloatType, FloatType> fftConv = fftConvolution1;
				fftConv.replaceImage( image );
				fftConv.process();

				return fftConv.getResult();
			}
		}
		else if ( useCUDA && !useCPU && numDevices == 1 )
		{
			// one CUDA device: process the block(s) sequentially on device0
			final Image< FloatType > result = image.createNewImage();
			final Image< FloatType > block = factory.createImage( blockSize );

			try
			{
				for ( int i = 0; i < blocks.length; ++i )
					LRFFTThreads.convolve1BlockCUDA( blocks[ i ], i, device0, image, result, block, kernel1, blockSize );
			}
			finally
			{
				block.close();
			}

			return result;
		}
		else
		{
			// this implies useBlocks, otherwise we cannot combine several devices;
			// one thread per listed device, blocks are handed out through the shared counter
			final Image< FloatType > result = image.createNewImage();
			final AtomicInteger ai = new AtomicInteger();
			final Thread[] threads = SimpleMultiThreading.newThreads( deviceList.length );

			for ( int i = 0; i < deviceList.length; ++i )
			{
				if ( deviceList[ i ] == -1 )
					threads[ i ] = LRFFTThreads.getCPUThread1( ai, blocks, blockSize, factory, image, result, fftConvolution1 );
				else
					threads[ i ] = LRFFTThreads.getCUDAThread1( ai, blocks, blockSize, factory, image, result, deviceList[ i ], kernel1 );
			}

			SimpleMultiThreading.startAndJoin( threads );

			return result;
		}
	}

	/**
	 * Creates an image of the given dimensions backed by an ArrayContainer over a FloatArray built
	 * from data (NOTE(review): presumably the array is wrapped, not copied — confirm).
	 *
	 * @param data - the pixel values
	 * @param dim - the image dimensions
	 * @return the new image
	 */
	final public static Image<FloatType> createImageFromArray( final float[] data, final int[] dim )
	{
		final FloatAccess access = new FloatArray( data );
		final Array<FloatType, FloatAccess> array =
			new Array<FloatType, FloatAccess>(new ArrayContainerFactory(), access, dim, 1 );

		// create a Type that is linked to the container
		final FloatType linkedType = new FloatType( array );

		// pass it to the DirectAccessContainer
		array.setLinkedType( linkedType );

		return new Image<FloatType>(array, new FloatType());
	}

	/**
	 * convolves the image with kernel2 (the kernel derived from kernel1 in init)
	 *
	 * @param image - the image to convolve with
	 * @return the convolved image; without blocks on the CPU this is the convolver's result image,
	 *         otherwise a new image of the same size as the input
	 */
	public Image< FloatType > convolve2( final Image< FloatType > image )
	{
		if ( useCPU && !useCUDA )
		{
			if ( useBlocks )
			{
				final Image< FloatType > result = image.createNewImage();
				final Image< FloatType > block = factory.createImage( blockSize );

				try
				{
					for ( int i = 0; i < blocks.length; ++i )
						LRFFTThreads.convolve2BlockCPU( blocks[ i ], image, result, block, fftConvolution2 );
				}
				finally
				{
					block.close();
				}

				return result;
			}
			else
			{
				// one big image on the CPU: reuse the convolver created in init
				final FourierConvolution<FloatType, FloatType> fftConv = fftConvolution2;
				fftConv.replaceImage( image );
				fftConv.process();

				return fftConv.getResult();
			}
		}
		else if ( useCUDA && !useCPU && numDevices == 1 )
		{
			// one CUDA device: process the block(s) sequentially on device0
			final Image< FloatType > result = image.createNewImage();
			final Image< FloatType > block = factory.createImage( blockSize );

			try
			{
				for ( int i = 0; i < blocks.length; ++i )
					LRFFTThreads.convolve2BlockCUDA( blocks[ i ], device0, image, result, block, kernel2, blockSize );
			}
			finally
			{
				block.close();
			}

			return result;
		}
		else
		{
			// this implies useBlocks, otherwise we cannot combine several devices;
			// one thread per listed device, blocks are handed out through the shared counter
			final Image< FloatType > result = image.createNewImage();
			final AtomicInteger ai = new AtomicInteger();
			final Thread[] threads = SimpleMultiThreading.newThreads( deviceList.length );

			for ( int i = 0; i < deviceList.length; ++i )
			{
				if ( deviceList[ i ] == -1 )
					threads[ i ] = LRFFTThreads.getCPUThread2( ai, blocks, blockSize, factory, image, result, fftConvolution2 );
				else
					threads[ i ] = LRFFTThreads.getCUDAThread2( ai, blocks, blockSize, factory, image, result, deviceList[ i ], kernel2 );
			}

			SimpleMultiThreading.startAndJoin( threads );

			return result;
		}
	}

	/**
	 * Creates a copy of this view with cloned image, weight and kernels. The device list, block size
	 * and views list are shared with this instance.
	 *
	 * NOTE(review): the copied convolvers are built on this instance's convolver images/kernels,
	 * are processed immediately, and do not get setNumThreads()/setKeepImgFFT(false) like the ones
	 * created in init — confirm this is intended.
	 */
	@Override
	public LRFFT clone()
	{
		final LRFFT viewClone = new LRFFT( this.image.clone(), this.weight.clone(), this.kernel1.clone(), deviceList, useBlocks, blockSize );

		viewClone.numViews = numViews;
		viewClone.iterationType = iterationType;
		viewClone.views = views;
		viewClone.iteration = iteration;

		if ( this.kernel2 != null )
			viewClone.kernel2 = kernel2.clone();

		if ( this.viewContribution != null )
			viewClone.viewContribution = this.viewContribution.clone();

		if ( this.fftConvolution1 != null )
		{
			viewClone.fftConvolution1 = new FourierConvolution<FloatType, FloatType>( fftConvolution1.getImage(), fftConvolution1.getKernel() );
			viewClone.fftConvolution1.process();
		}

		if ( this.fftConvolution2 != null )
		{
			viewClone.fftConvolution2 = new FourierConvolution<FloatType, FloatType>( fftConvolution2.getImage(), fftConvolution2.getKernel() );
			viewClone.fftConvolution2.process();
		}

		return viewClone;
	}
}