/*- * #%L * Fiji distribution of ImageJ for the life sciences. * %% * Copyright (C) 2007 - 2017 Fiji developers. * %% * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as * published by the Free Software Foundation, either version 2 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public * License along with this program. If not, see * <http://www.gnu.org/licenses/gpl-2.0.html>. * #L% */ package mpicbg.spim.postprocessing.deconvolution2; import java.util.ArrayList; import java.util.concurrent.atomic.AtomicInteger; import mpicbg.imglib.algorithm.fft.FourierConvolution; import mpicbg.imglib.algorithm.mirror.MirrorImage; import mpicbg.imglib.container.ContainerFactory; import mpicbg.imglib.container.array.Array; import mpicbg.imglib.container.array.ArrayContainerFactory; import mpicbg.imglib.container.basictypecontainer.FloatAccess; import mpicbg.imglib.container.basictypecontainer.array.FloatArray; import mpicbg.imglib.container.constant.ConstantContainer; import mpicbg.imglib.container.imageplus.ImagePlusContainer; import mpicbg.imglib.container.imageplus.ImagePlusContainerFactory; import mpicbg.imglib.cursor.Cursor; import mpicbg.imglib.image.Image; import mpicbg.imglib.image.ImageFactory; import mpicbg.imglib.image.display.imagej.ImageJFunctions; import mpicbg.imglib.multithreading.SimpleMultiThreading; import mpicbg.imglib.outofbounds.OutOfBoundsStrategyValueFactory; import mpicbg.imglib.type.numeric.real.FloatType; import mpicbg.imglib.wrapper.ImgLib1; import mpicbg.imglib.wrapper.ImgLib2; import mpicbg.spim.io.IOFunctions; import net.imglib2.exception.ImgLibException; import 
net.imglib2.img.Img;
import net.imglib2.img.imageplus.ImagePlusImg;
import spim.process.cuda.CUDAFourierConvolution;

/**
 * A single view of the multi-view deconvolution: the view's image, its weight image and its
 * point spread function (kernel1), plus a second kernel (kernel2) that is derived from kernel1 in
 * {@link #init(PSFTYPE, ArrayList)} according to the selected {@link PSFTYPE}.
 * <p>
 * Images can be convolved with either kernel ({@link #convolve1(Image)}, {@link #convolve2(Image)})
 * on the CPU, on CUDA devices, or on both, optionally split into blocks. In the device list,
 * {@code -1} selects the CPU and values {@code >= 0} select a CUDA device.
 */
public class LRFFT
{
	/** Selects how kernel2 is derived from kernel1 in {@link #init(PSFTYPE, ArrayList)}. */
	public static enum PSFTYPE { OPTIMIZATION_II, OPTIMIZATION_I, EFFICIENT_BAYESIAN, INDEPENDENT };

	/**
	 * CUDA convolution backend shared by all instances. It is never assigned in this class;
	 * presumably the caller must set it before any CUDA device is used. TODO confirm.
	 */
	public static CUDAFourierConvolution cuda = null;

	/** Image of this view, its weights, its PSF (kernel1) and the derived kernel (kernel2). */
	private Image<FloatType> image, weight, kernel1, kernel2;

	/** Only copied by {@link #clone()} in this class; presumably filled in by the package-level deconvolution code. */
	Image<FloatType> viewContribution = null;

	/** CPU convolutions with kernel1 / kernel2; null when no CPU device (-1) is requested. */
	FourierConvolution<FloatType, FloatType> fftConvolution1, fftConvolution2;

	/** Number of views in the acquisition; 0 means "not set" and is replaced by 1 in init(). */
	protected int numViews = 0;

	PSFTYPE iterationType;
	ArrayList< LRFFT > views;

	final boolean useBlocks, useCUDA, useCPU;
	final int[] blockSize, deviceList;
	final int device0, numDevices;
	final Block[] blocks;
	final ImageFactory< FloatType > factory;

	/**
	 * Used to determine whether the convolutions have already been computed for the current iteration.
	 * Reset to -1 whenever the image or the kernel changes.
	 */
	int iteration = -1;

	/**
	 * Creates a view from ImgLib2 images by wrapping them as ImgLib1 images (see {@link #wrap(Img)}).
	 *
	 * @param image - the input image of this view
	 * @param weight - the per-pixel weights of this view
	 * @param kernel - the point spread function of this view
	 * @param deviceList - devices to compute on (-1 = CPU, >= 0 = CUDA device id)
	 * @param useBlocks - whether to convolve block-wise
	 * @param blockSize - block size per dimension (only read when useBlocks is true)
	 */
	public LRFFT(
			final Img< net.imglib2.type.numeric.real.FloatType > image,
			final Img< net.imglib2.type.numeric.real.FloatType > weight,
			final Img< net.imglib2.type.numeric.real.FloatType > kernel,
			final int[] deviceList, final boolean useBlocks, final int[] blockSize )
	{
		this( wrap( image ), wrap( weight ), wrap( kernel ), deviceList, useBlocks, blockSize );
	}

	/**
	 * Wraps an ImgLib2 float image as an ImgLib1 image.
	 * ImagePlus-backed images are wrapped through their ImagePlus; all others go through
	 * {@link ImgLib2#wrapFloatToImgLib1}.
	 *
	 * @param i - the ImgLib2 image
	 * @return the ImgLib1 image, or null if the ImagePlus could not be obtained (the exception is printed)
	 */
	@SuppressWarnings("rawtypes")
	public static final Image< FloatType > wrap( final Img< net.imglib2.type.numeric.real.FloatType > i )
	{
		if ( i instanceof ImagePlusImg )
		{
			try
			{
				return ImageJFunctions.wrapFloat( ((ImagePlusImg) i).getImagePlus() );
			}
			catch ( ImgLibException e )
			{
				// the failure is reported on stderr and signalled to the caller by returning null
				e.printStackTrace();
				return null;
			}
		}
		else
		{
			return ImgLib2.wrapFloatToImgLib1( i );
		}
	}

	/**
	 * Wraps an ImgLib1 float image as an ImgLib2 image.
	 * ImagePlus-backed images are wrapped through their ImagePlus; all others go through
	 * {@link ImgLib1#wrapFloatToImgLib2}.
	 *
	 * @param i - the ImgLib1 image
	 * @return the ImgLib2 image, or null if the ImagePlus could not be obtained (the exception is printed)
	 */
	@SuppressWarnings("rawtypes")
	public static final Img< net.imglib2.type.numeric.real.FloatType > wrap( final Image< FloatType > i )
	{
		final ContainerFactory c = i.getContainerFactory();

		if ( c instanceof ImagePlusContainerFactory )
		{
			try
			{
				return net.imglib2.img.display.imagej.ImageJFunctions.wrapFloat( ((ImagePlusContainer)i.getContainer()).getImagePlus() );
			}
			catch ( mpicbg.imglib.exception.ImgLibException e )
			{
				// the failure is reported on stderr and signalled to the caller by returning null
				e.printStackTrace();
				return null;
			}
		}
		else
		{
			return ImgLib1.wrapFloatToImgLib2( i );
		}
	}

	/**
	 * Creates a view and sets up the blocks needed by the requested devices.
	 * <p>
	 * Blocks are used when requested, and also whenever a CUDA device is in the device list.
	 * In that case, without explicit blocks, one block covers the image extended by the kernel size minus one.
	 *
	 * @param image - the input image of this view
	 * @param weight - the per-pixel weights of this view
	 * @param kernel - the point spread function of this view
	 * @param deviceList - devices to compute on (-1 = CPU, >= 0 = CUDA device id); must not be empty
	 * @param useBlocks - whether to convolve block-wise
	 * @param blockSize - block size per dimension (only read when useBlocks is true)
	 * @throws IllegalArgumentException if deviceList is null or empty
	 */
	public LRFFT(
			final Image<FloatType> image,
			final Image<FloatType> weight,
			final Image<FloatType> kernel,
			final int[] deviceList,
			final boolean useBlocks,
			final int[] blockSize )
	{
		if ( deviceList == null || deviceList.length == 0 )
			throw new IllegalArgumentException( "deviceList must contain at least one device (-1 = CPU, >= 0 = CUDA device)." );

		this.image = image;
		this.kernel1 = kernel;
		this.weight = weight;

		this.deviceList = deviceList;
		this.device0 = deviceList[ 0 ];
		this.numDevices = deviceList.length;

		// figure out if we need GPU and/or CPU
		boolean anyGPU = false;
		boolean anyCPU = false;

		for ( final int i : deviceList )
		{
			if ( i >= 0 )
				anyGPU = true;
			else if ( i == -1 )
				anyCPU = true;
		}

		this.useCUDA = anyGPU;
		this.useCPU = anyCPU;

		if ( useBlocks )
		{
			this.useBlocks = true;

			// Copy the requested block size (one entry per image dimension) so the caller's array
			// is never modified; the block size might change during division if it is too small.
			this.blockSize = new int[ image.getNumDimensions() ];

			for ( int d = 0; d < this.blockSize.length; ++d )
				this.blockSize[ d ] = blockSize[ d ];

			this.blocks = Block.divideIntoBlocks( image.getDimensions(), this.blockSize, kernel.getDimensions() );

			IOFunctions.println( "Number of blocks: " + this.blocks.length );

			this.factory = new ImageFactory< FloatType >( new FloatType(), new ArrayContainerFactory() );
		}
		else if ( this.useCUDA ) // and no blocks, i.e. one big block
		{
			// The CUDA paths in convolve1/convolve2 always work on blocks, so define the block size
			// so that it is one single block: image size + kernel size - 1 in every dimension.
			this.useBlocks = true;

			this.blockSize = new int[ image.getNumDimensions() ];

			for ( int d = 0; d < this.blockSize.length; ++d )
				this.blockSize[ d ] = image.getDimension( d ) + kernel.getDimension( d ) - 1;

			this.blocks = Block.divideIntoBlocks( image.getDimensions(), this.blockSize, kernel.getDimensions() );

			this.factory = new ImageFactory< FloatType >( new FloatType(), new ArrayContainerFactory() );
		}
		else
		{
			// CPU only, no blocks: the whole image is convolved at once
			this.blocks = null;
			this.blockSize = null;
			this.factory = null;
			this.useBlocks = false;
		}
	}

	/**
	 * Creates a view whose weight image is constant 1 everywhere.
	 *
	 * @param image - the input image of this view
	 * @param kernel - the point spread function of this view
	 * @param deviceList - devices to compute on (-1 = CPU, >= 0 = CUDA device id); must not be empty
	 * @param useBlocks - whether to convolve block-wise
	 * @param blockSize - block size per dimension (only read when useBlocks is true)
	 */
	public LRFFT( final Image<FloatType> image, final Image<FloatType> kernel, final int[] deviceList, final boolean useBlocks, final int[] blockSize )
	{
		this( image, new Image< FloatType > ( new ConstantContainer< FloatType >( image.getDimensions(), new FloatType( 1 ) ), new FloatType() ), kernel, deviceList, useBlocks, blockSize );
	}

	/**
	 * @param numViews - the number of views in the acquisition, determines the exponential of the kernel
	 */
	protected void setNumViews( final int numViews ) { this.numViews = numViews; }

	/**
	 * This method is called once all views are added to the {@link LRInput}.
	 * <p>
	 * It normalizes kernel1 in place, derives kernel2 according to iterationType and,
	 * if the CPU is used, creates the CPU convolutions for both kernels.
	 *
	 * @param iterationType - how kernel2 is derived from kernel1
	 * @param views - all views of the acquisition (including this one)
	 */
	protected void init( final PSFTYPE iterationType, final ArrayList< LRFFT > views )
	{
		// normalize kernel so that sum of all pixels == 1
		AdjustInput.normImage( kernel1 );

		this.iterationType = iterationType;
		this.views = views;

		if ( numViews == 0 )
		{
			System.out.println( "Warning, numViews was not set." );
			numViews = 1;
		}

		if ( numViews == 1 || iterationType == PSFTYPE.INDEPENDENT )
		{
			// compute the inverted kernel (switch dimensions)
			this.kernel2 = computeInvertedKernel( this.kernel1 );
		}
		else if ( iterationType == PSFTYPE.EFFICIENT_BAYESIAN )
		{
			// compute the compound kernel P_v^compound of the efficient bayesian multi-view deconvolution
			// for the current view \phi_v(x_v)
			//
			// P_v^compound = P_v^{*} prod{w \in W_v} P_v^{*} \ast P_w \ast P_w^{*}

			// we first get P_v^{*} -> {*} refers to the inverted coordinates
			final Image< FloatType > tmp = computeInvertedKernel( this.kernel1.clone() );

			// now for each view: w \in W_v
			for ( final LRFFT view : views )
			{
				if ( view != this )
				{
					// convolve first P_v^{*} with P_w
					final Image< FloatType > conv1 = convolveKernels( computeInvertedKernel( this.kernel1 ), view.kernel1 );

					// and now convolve the result with P_w^{*}
					final Image< FloatType > conv2 = convolveKernels( conv1, computeInvertedKernel( view.kernel1 ) );

					// multiply the result with P_v^{*} yielding the compound kernel
					multiplyInPlace( tmp, conv2 );
				}
			}

			// norm the compound kernel
			AdjustInput.normImage( tmp );

			// set it as kernel2 of the deconvolution
			this.kernel2 = tmp;
		}
		else if ( iterationType == PSFTYPE.OPTIMIZATION_I )
		{
			// compute the simplified compound kernel P_v^compound of the efficient bayesian multi-view deconvolution
			// for the current view \phi_v(x_v)
			//
			// P_v^compound = P_v^{*} prod{w \in W_v} P_v^{*} \ast P_w
			//
			// The product is built from the NON-inverted kernel1 (P_v \ast P_w^{*} factors);
			// the coordinates are inverted once at the end.
			final Image< FloatType > tmp = this.kernel1.clone();

			// now for each view: w \in W_v
			for ( final LRFFT view : views )
			{
				if ( view != this )
				{
					final Image< FloatType > conv = convolveKernels( this.kernel1, computeInvertedKernel( view.kernel1 ) );

					// multiply with the kernel
					multiplyInPlace( tmp, conv );
				}
			}

			// norm the compound kernel
			AdjustInput.normImage( tmp );

			// compute the inverted kernel
			this.kernel2 = computeInvertedKernel( tmp );
		}
		else //if ( iterationType == PSFTYPE.OPTIMIZATION_II )
		{
			// compute the exponential kernel (kernel1 ^ numViews, pixel-wise) and its inverse
			final Image< FloatType > exponentialKernel = computeExponentialKernel( this.kernel1, numViews );

			// norm the exponential kernel
			AdjustInput.normImage( exponentialKernel );

			// compute the inverted exponential kernel
			this.kernel2 = computeInvertedKernel( exponentialKernel );
		}

		if ( useCPU )
		{
			if ( useBlocks )
			{
				// the convolutions operate on block-sized images; the actual block is swapped in later
				final Image< FloatType > block = factory.createImage( blockSize );

				this.fftConvolution1 = createConvolution( block, this.kernel1 );
				this.fftConvolution2 = createConvolution( block, this.kernel2 );
			}
			else
			{
				this.fftConvolution1 = createConvolution( this.image, this.kernel1 );
				this.fftConvolution2 = createConvolution( this.image, this.kernel2 );
			}
		}
		else
		{
			this.fftConvolution1 = null;
			this.fftConvolution2 = null;
		}
	}

	/**
	 * Creates a CPU Fourier convolution configured like all convolutions of this class:
	 * multi-threaded, and without keeping the image FFT.
	 */
	private static FourierConvolution< FloatType, FloatType > createConvolution( final Image< FloatType > img, final Image< FloatType > kernel )
	{
		final FourierConvolution< FloatType, FloatType > conv = new FourierConvolution< FloatType, FloatType >( img, kernel );
		conv.setNumThreads();
		conv.setKeepImgFFT( false );
		return conv;
	}

	/**
	 * Convolves two kernel images once, using a value out-of-bounds strategy, and returns the result.
	 */
	private static Image< FloatType > convolveKernels( final Image< FloatType > img, final Image< FloatType > kernel )
	{
		final FourierConvolution< FloatType, FloatType > conv = createConvolution( img, kernel );
		conv.setImageOutOfBoundsStrategy( new OutOfBoundsStrategyValueFactory< FloatType >() );
		conv.process();
		return conv.getResult();
	}

	/**
	 * Multiplies target pixel by pixel with factor, in place.
	 * Assumes both images have the same size and iteration order (not checked).
	 */
	private static void multiplyInPlace( final Image< FloatType > target, final Image< FloatType > factor )
	{
		final Cursor< FloatType > cursor = target.createCursor();

		for ( final FloatType t : factor )
		{
			cursor.fwd();
			cursor.getType().set( t.get() * cursor.getType().get() );
		}
	}

	/**
	 * @param kernel - the kernel (not modified)
	 * @param numViews - the exponent
	 * @return a copy of kernel with every pixel raised to the power numViews
	 */
	public static Image<FloatType> computeExponentialKernel( final Image<FloatType> kernel, final int numViews )
	{
		final Image<FloatType> exponentialKernel = kernel.clone();

		for ( final FloatType f : exponentialKernel )
			f.set( pow( f.get(), numViews ) );

		return exponentialKernel;
	}

	/**
	 * @param kernel - the kernel (not modified)
	 * @return a copy of kernel mirrored along every dimension (inverted coordinates)
	 */
	public static Image< FloatType > computeInvertedKernel( final Image< FloatType > kernel )
	{
		final Image< FloatType > invKernel = kernel.clone();

		for ( int d = 0; d < invKernel.getNumDimensions(); ++d )
			new MirrorImage< FloatType >( invKernel, d ).process();

		return invKernel;
	}

	/**
	 * Integer power by repeated float multiplication.
	 * Returns 1 for power == 0 and the reciprocal for negative powers.
	 */
	private static float pow( final float value, final int power )
	{
		if ( power < 0 )
			return 1.0f / pow( value, -power );

		float result = 1.0f;

		for ( int i = 0; i < power; ++i )
			result *= value;

		return result;
	}

	/** Replaces the image of this view and invalidates the current iteration. */
	public void setImage( final Image<FloatType> image )
	{
		this.image = image;
		setCurrentIteration( -1 );
	}

	public void setWeight( final Image<FloatType> weight ) { this.weight = weight; }

	/** Replaces kernel1, re-runs {@link #init(PSFTYPE, ArrayList)} and invalidates the current iteration. */
	public void setKernel( final Image<FloatType> kernel )
	{
		this.kernel1 = kernel;

		init( iterationType, views );

		setCurrentIteration( -1 );
	}

	public Image<FloatType> getImage() { return image; }
	public Image<FloatType> getWeight() { return weight; }
	public Image<FloatType> getKernel1() { return kernel1; }
	public Image<FloatType> getKernel2() { return kernel2; }

	public void setCurrentIteration( final int i ) { this.iteration = i; }
	public int getCurrentIteration() { return iteration; }

	/**
	 * Convolves the image with kernel1.
	 *
	 * @param image - the image to convolve with
	 * @return the convolved image: a new image in the block-wise paths, or the result of
	 *         fftConvolution1 in the CPU-only single-image path
	 */
	public Image< FloatType > convolve1( final Image< FloatType > image )
	{
		if ( useCPU && !useCUDA )
		{
			if ( useBlocks )
			{
				final Image< FloatType > result = image.createNewImage();
				final Image< FloatType > block = factory.createImage( blockSize );

				try
				{
					for ( int i = 0; i < blocks.length; ++i )
						LRFFTThreads.convolve1BlockCPU( blocks[ i ], i, image, result, block, fftConvolution1 );
				}
				finally
				{
					block.close();
				}

				return result;
			}
			else
			{
				// CPU only, whole image at once
				final FourierConvolution<FloatType, FloatType> fftConv = fftConvolution1;
				fftConv.replaceImage( image );
				fftConv.process();

				return fftConv.getResult();
			}
		}
		else if ( useCUDA && !useCPU && numDevices == 1 )
		{
			// a single CUDA device, processing one block after the other
			final Image< FloatType > result = image.createNewImage();
			final Image< FloatType > block = factory.createImage( blockSize );

			try
			{
				for ( int i = 0; i < blocks.length; ++i )
					LRFFTThreads.convolve1BlockCUDA( blocks[ i ], i, device0, image, result, block, kernel1, blockSize );
			}
			finally
			{
				block.close();
			}

			return result;
		}
		else
		{
			// this implies useBlocks, otherwise we cannot combine several devices:
			// one thread per device, the AtomicInteger is shared by all device threads
			final Image< FloatType > result = image.createNewImage();
			final AtomicInteger ai = new AtomicInteger();

			final Thread[] threads = SimpleMultiThreading.newThreads( deviceList.length );

			for ( int i = 0; i < deviceList.length; ++i )
			{
				if ( deviceList[ i ] == -1 )
					threads[ i ] = LRFFTThreads.getCPUThread1( ai, blocks, blockSize, factory, image, result, fftConvolution1 );
				else
					threads[ i ] = LRFFTThreads.getCUDAThread1( ai, blocks, blockSize, factory, image, result, deviceList[ i ], kernel1 );
			}

			SimpleMultiThreading.startAndJoin( threads );

			return result;
		}
	}

	/**
	 * Wraps a float array as an array-backed image without copying the data.
	 *
	 * @param data - the pixel values
	 * @param dim - the image dimensions
	 * @return an image whose container is backed by data
	 */
	final public static Image<FloatType> createImageFromArray( final float[] data, final int[] dim )
	{
		final FloatAccess access = new FloatArray( data );
		final Array<FloatType, FloatAccess> array =
			new Array<FloatType, FloatAccess>(new ArrayContainerFactory(), access, dim, 1 );

		// create a Type that is linked to the container
		final FloatType linkedType = new FloatType( array );

		// pass it to the DirectAccessContainer
		array.setLinkedType( linkedType );

		return new Image<FloatType>(array, new FloatType());
	}

	/**
	 * Convolves the image with kernel2 (derived from kernel1, see {@link #init(PSFTYPE, ArrayList)}).
	 *
	 * @param image - the image to convolve with
	 * @return the convolved image: a new image in the block-wise paths, or the result of
	 *         fftConvolution2 in the CPU-only single-image path
	 */
	public Image< FloatType > convolve2( final Image< FloatType > image )
	{
		if ( useCPU && !useCUDA )
		{
			if ( useBlocks )
			{
				final Image< FloatType > result = image.createNewImage();
				final Image< FloatType > block = factory.createImage( blockSize );

				try
				{
					for ( int i = 0; i < blocks.length; ++i )
						LRFFTThreads.convolve2BlockCPU( blocks[ i ], image, result, block, fftConvolution2 );
				}
				finally
				{
					block.close();
				}

				return result;
			}
			else
			{
				// CPU only, whole image at once
				final FourierConvolution<FloatType, FloatType> fftConv = fftConvolution2;
				fftConv.replaceImage( image );
				fftConv.process();

				return fftConv.getResult();
			}
		}
		else if ( useCUDA && !useCPU && numDevices == 1 )
		{
			// a single CUDA device, processing one block after the other
			final Image< FloatType > result = image.createNewImage();
			final Image< FloatType > block = factory.createImage( blockSize );

			try
			{
				for ( int i = 0; i < blocks.length; ++i )
					LRFFTThreads.convolve2BlockCUDA( blocks[ i ], device0, image, result, block, kernel2, blockSize );
			}
			finally
			{
				block.close();
			}

			return result;
		}
		else
		{
			// this implies useBlocks, otherwise we cannot combine several devices:
			// one thread per device, the AtomicInteger is shared by all device threads
			final Image< FloatType > result = image.createNewImage();
			final AtomicInteger ai = new AtomicInteger();

			final Thread[] threads = SimpleMultiThreading.newThreads( deviceList.length );

			for ( int i = 0; i < deviceList.length; ++i )
			{
				if ( deviceList[ i ] == -1 )
					threads[ i ] = LRFFTThreads.getCPUThread2( ai, blocks, blockSize, factory, image, result, fftConvolution2 );
				else
					threads[ i ] = LRFFTThreads.getCUDAThread2( ai, blocks, blockSize, factory, image, result, deviceList[ i ], kernel2 );
			}

			SimpleMultiThreading.startAndJoin( threads );

			return result;
		}
	}

	/**
	 * Creates a copy of this view.
	 * The images and kernels are cloned; deviceList, blockSize and the views list are shared.
	 * The CPU convolutions are re-created with the same configuration as in
	 * {@link #init(PSFTYPE, ArrayList)} and processed once.
	 */
	@Override
	public LRFFT clone()
	{
		final LRFFT viewClone = new LRFFT( this.image.clone(), this.weight.clone(), this.kernel1.clone(), deviceList, useBlocks, blockSize );

		viewClone.numViews = numViews;
		viewClone.iterationType = iterationType;
		viewClone.views = views;
		viewClone.iteration = iteration;

		if ( this.kernel2 != null )
			viewClone.kernel2 = kernel2.clone();

		if ( this.viewContribution != null )
			viewClone.viewContribution = this.viewContribution.clone();

		if ( this.fftConvolution1 != null )
		{
			viewClone.fftConvolution1 = createConvolution( fftConvolution1.getImage(), fftConvolution1.getKernel() );
			viewClone.fftConvolution1.process();
		}

		if ( this.fftConvolution2 != null )
		{
			viewClone.fftConvolution2 = createConvolution( fftConvolution2.getImage(), fftConvolution2.getKernel() );
			viewClone.fftConvolution2.process();
		}

		return viewClone;
	}
}