/*- * #%L * Fiji distribution of ImageJ for the life sciences. * %% * Copyright (C) 2007 - 2017 Fiji developers. * %% * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as * published by the Free Software Foundation, either version 2 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public * License along with this program. If not, see * <http://www.gnu.org/licenses/gpl-2.0.html>. * #L% */ package mpicbg.spim.postprocessing.deconvolution; import ij.IJ; import java.math.BigDecimal; import java.math.MathContext; import java.util.ArrayList; import java.util.Date; import java.util.concurrent.atomic.AtomicInteger; import mpicbg.imglib.algorithm.fft.FourierConvolution; import mpicbg.imglib.cursor.Cursor; import mpicbg.imglib.image.Image; import mpicbg.imglib.image.display.imagej.ImageJFunctions; import mpicbg.imglib.multithreading.SimpleMultiThreading; import mpicbg.imglib.type.numeric.real.FloatType; import mpicbg.util.RealSum; public class LucyRichardsonMultiViewDeconvolution { public static boolean debug = false; public static int debugInterval = 10; public static Image<FloatType> lucyRichardsonMultiView( final ArrayList<LucyRichardsonFFT> data, final int minIterations, final int maxIterations, final boolean multiplicative, final double lambda, final int numThreads ) { //final int numThreads = Threads.numThreads(); final int numViews = data.size(); //final long numPixels = data.get( 0 ).getImage().getNumPixels(); //final double minValue = (1.0 / ( 10000.0 * numPixels ) ); final double minValue = 0.0001; final Image<FloatType> psi = data.get( 0 ).getImage().createNewImage( "psi (deconvolved image)" ); // // for every image the integral of all pixel values // final AtomicInteger ai = new AtomicInteger(0); Thread[] threads = SimpleMultiThreading.newThreads( numThreads ); for ( int ithread = 0; ithread < threads.length; ++ithread ) threads[ithread] = new Thread(new Runnable() { public void run() { final int myNumber = ai.getAndIncrement(); for ( int i = 0; i < data.size(); ++i ) if ( i%numThreads == myNumber ) { //IJ.log( new Date( System.currentTimeMillis() ) + " Norming image " + (i+1) ); //normImage( data.get( i ).getImage() ); IJ.log( new Date( System.currentTimeMillis() ) + " Norming kernel " + (i+1) ); normImage( data.get( i ).getKernel() ); } } }); //IJ.log( "NumThreads: " + threads.length ); SimpleMultiThreading.startAndJoin( threads ); // the overlapping area has the same energy final double avg = normAllImages( data ); //for ( LucyRichardsonFFT img : data ) //{ // img.getImage().getDisplay().setMinMax(); // ImageJFunctions.copyToImagePlus( img.getImage() ).show(); //} //SimpleMultiThreading.threadHaltUnClean(); IJ.log( "Average intensity in overlapping area: " + avg ); // // the real data image psi is initialized with the average // final Cursor<FloatType> cursorPsiGlobal = psi.createCursor(); //final float avgFloat = (float)( 1.0 / (double)numPixels ); final float avgFloat = (float)avg; while ( cursorPsiGlobal.hasNext() ) { cursorPsiGlobal.fwd(); cursorPsiGlobal.getType().set( avgFloat ); } cursorPsiGlobal.reset(); final Image<FloatType> nextPsi = psi.createNewImage(); final Cursor<FloatType> cursorNextPsiGlobal = nextPsi.createCursor(); // // Start iteration // double sumChange = 0; int i = 0; do { IJ.log( "iteration: " + i++ + " (" + new Date(System.currentTimeMillis()) + ")" ); // // Set next psi to 1, we then multiply the results from the different views // cursorNextPsiGlobal.reset(); while ( cursorNextPsiGlobal.hasNext() ) { cursorNextPsiGlobal.fwd(); cursorNextPsiGlobal.getType().set( 1 ); } // // For each view we have to divide the image by the blurred image and convolve with the kernel // Then we multiply the result to psi to get the new estimate of psi // ai.set( 0 ); threads = SimpleMultiThreading.newThreads( numThreads ); for ( int ithread = 0; ithread < threads.length; ++ithread ) threads[ithread] = new Thread(new Runnable() { public void run() { final int myNumber = ai.getAndIncrement(); for ( int view = 0; view < numViews; ++view ) if ( view%numThreads == myNumber ) { final LucyRichardsonFFT processingData = data.get( view ); // convolve psi (current guess of the image) with the PSF of the current view final FourierConvolution<FloatType, FloatType> fftConvolution = processingData.getFFTConvolution(); fftConvolution.replaceImage( psi ); fftConvolution.process(); final Image<FloatType> psiBlurred = fftConvolution.getResult(); //psiBlurred.getDisplay().setMinMax(); //psiBlurred.setName( "psiBlurred " + view ); //ImageJFunctions.copyToImagePlus( psiBlurred ).show(); // compute quotient img/psiBlurred final Cursor<FloatType> cursorImg = processingData.getImage().createCursor(); final Cursor<FloatType> cursorPsiBlurred = psiBlurred.createCursor(); while ( cursorImg.hasNext() ) { cursorImg.fwd(); cursorPsiBlurred.fwd(); final float imgValue = cursorImg.getType().get(); final float psiBlurredValue = cursorPsiBlurred.getType().get(); cursorPsiBlurred.getType().set( imgValue / psiBlurredValue ); } cursorImg.close(); cursorPsiBlurred.close(); //psiBlurred.getDisplay().setMinMax(); //psiBlurred.setName( "img/psiBlurred " + view ); //ImageJFunctions.copyToImagePlus( psiBlurred ).show(); // blur the residuals image with the kernel fftConvolution.replaceImage( psiBlurred ); fftConvolution.process(); processingData.setViewContribution( fftConvolution.getResult() ); //fftConvolution.getResult().getDisplay().setMinMax(); //fftConvolution.getResult().setName( "conv(img/psiBlurred) " + view ); //ImageJFunctions.copyToImagePlus( fftConvolution.getResult() ).show(); // close unecessary psiBlurred psiBlurred.close(); } } }); SimpleMultiThreading.startAndJoin( threads ); // // multiply residualsBlurred with nextPsi and compute the n-root of each pixel ( where n is the number of views ) // this cannot be done in multithreaded as it would collide // final ArrayList<Cursor<FloatType>> blurredResidualsCursors = new ArrayList<Cursor<FloatType>>(); //final ArrayList<Cursor<FloatType>> imageCursors = new ArrayList<Cursor<FloatType>>(); final ArrayList<Cursor<FloatType>> weightCursors = new ArrayList<Cursor<FloatType>>(); for ( int view = 0; view < numViews; ++view ) { blurredResidualsCursors.add( data.get( view ).getViewContribution().createCursor() ); //imageCursors.add( data.get( view ).getImage().createCursor() ); if ( data.get( view ).getWeight() != null ) weightCursors.add( data.get( view ).getWeight().createCursor() ); //data.get( view ).getWeight().getDisplay().setMinMax(); //data.get( view ).getWeight().setName( "weight " + view ); //ImageJFunctions.copyToImagePlus( data.get( view ).getWeight() ).show(); } cursorNextPsiGlobal.reset(); cursorPsiGlobal.reset(); while ( cursorNextPsiGlobal.hasNext() ) { cursorNextPsiGlobal.fwd(); double value = cursorNextPsiGlobal.getType().get(); if ( !multiplicative ) value = 0; double num = 0; if ( weightCursors.size() > 0 ) { for ( int h = 0; h < numViews; ++h ) { final Cursor<FloatType> cursorResidualsBlurred = blurredResidualsCursors.get( h ); //final Cursor<FloatType> cursorImage = imageCursors.get( h ); final Cursor<FloatType> cursorWeight = weightCursors.get( h ); cursorResidualsBlurred.fwd(); //cursorImage.fwd(); cursorWeight.fwd(); final float weight = cursorWeight.getType().get(); if ( weight > 0 ) { if ( multiplicative ) value *= Math.pow( cursorResidualsBlurred.getType().get(), weight ); else value += cursorResidualsBlurred.getType().get() * weight; num += weight; } } } else { for ( int h = 0; h < numViews; ++h ) { final Cursor<FloatType> cursorResidualsBlurred = blurredResidualsCursors.get( h ); //final Cursor<FloatType> cursorImage = imageCursors.get( h ); cursorResidualsBlurred.fwd(); //cursorImage.fwd(); value *= cursorResidualsBlurred.getType().get(); num++; } } cursorPsiGlobal.fwd(); if ( num > 0 ) { if ( multiplicative ) value = (double)cursorPsiGlobal.getType().get() * Math.pow( value, 1.0/num ); else value = (double)cursorPsiGlobal.getType().get() * cursorNextPsiGlobal.getType().get() * value/num; } else { // maybe that works ... value = minValue; //(double)cursorPsiGlobal.getType().get() * cursorNextPsiGlobal.getType().get(); } cursorNextPsiGlobal.getType().set( (float)value ); } //nextPsi.getDisplay().setMinMax(); //nextPsi.setName( "nextPsi " ); //ImageJFunctions.copyToImagePlus( nextPsi ).show(); //SimpleMultiThreading.threadHaltUnClean(); for ( final Cursor<FloatType> cursorResidualsBlurred : blurredResidualsCursors ) cursorResidualsBlurred.close(); //for ( final Cursor<FloatType> cursorImage : imageCursors ) // cursorImage.close(); for ( final Cursor<FloatType> cursorWeight : weightCursors ) cursorWeight.close(); // // perform Tikhonov regularization if desired // if ( lambda > 0 ) { cursorNextPsiGlobal.reset(); while ( cursorNextPsiGlobal.hasNext() ) { cursorNextPsiGlobal.fwd(); final float f = cursorNextPsiGlobal.getType().get(); final float reg = (float)( (Math.sqrt( 1.0 + 2.0*lambda*f ) - 1.0) / lambda ); cursorNextPsiGlobal.getType().set( reg ); } } // // Update psi for next iteration // cursorPsiGlobal.reset(); cursorNextPsiGlobal.reset(); sumChange = 0; double maxChange = -1; while ( cursorNextPsiGlobal.hasNext() ) { cursorPsiGlobal.fwd(); cursorNextPsiGlobal.fwd(); final float lastPsiValue = cursorPsiGlobal.getType().get(); final float nextPsiValue; if ( Float.isNaN( cursorNextPsiGlobal.getType().get() ) ) nextPsiValue = (float)minValue; else nextPsiValue = (float)Math.max( minValue, cursorNextPsiGlobal.getType().get() ); cursorPsiGlobal.getType().set( nextPsiValue ); final float change = Math.abs( lastPsiValue - nextPsiValue ); sumChange += change; maxChange = Math.max( maxChange, change ); } IJ.log("------------------------------------------------"); IJ.log(" Change: " + sumChange ); IJ.log(" Max Change per Pixel: " + maxChange ); IJ.log("------------------------------------------------"); System.out.println( i + "\t" + sumChange + "\t" + maxChange ); if ( debug && i % debugInterval == 0 ) { Image<FloatType> psiCopy = psi.clone(); //ViewDataBeads.normalizeImage( psiCopy ); psiCopy.setName( "Iteration " + i + " l=" + lambda ); psiCopy.getDisplay().setMinMax( 0, 1 ); ImageJFunctions.copyToImagePlus( psiCopy ).show(); psiCopy.close(); psiCopy = null; } } while ( i < maxIterations ); cursorPsiGlobal.close(); cursorNextPsiGlobal.close(); nextPsi.close(); return psi; } public static double normAllImages( final ArrayList<LucyRichardsonFFT> data ) { // the individual sums of the overlapping area //final double[] sums = new double[ data.size() ]; final RealSum sum = new RealSum(); // the number of overlapping pixels long count = 0; final ArrayList<Cursor<FloatType>> cursorsImage = new ArrayList<Cursor<FloatType>>(); final ArrayList<Cursor<FloatType>> cursorsWeight = new ArrayList<Cursor<FloatType>>(); for ( final LucyRichardsonFFT fft : data ) { cursorsImage.add( fft.getImage().createCursor() ); if ( fft.getWeight() != null ) cursorsWeight.add( fft.getWeight().createCursor() ); } final Cursor<FloatType> cursor = cursorsImage.get( 0 ); // sum overlapping area individually /*A:*/ while ( cursor.hasNext() ) { for ( final Cursor<FloatType> c : cursorsImage ) c.fwd(); for ( final Cursor<FloatType> c : cursorsWeight ) c.fwd(); // only sum if all views overlap //for ( final Cursor<FloatType> c : cursorsWeight ) // if ( c.getType().get() == 0 ) // continue A; // sum up individual intensities double sumLocal = 0; int countLocal = 0; for ( int i = 0; i < cursorsImage.size(); ++i ) { if ( cursorsWeight.get( i ).getType().get() != 0 ) { sumLocal += cursorsImage.get( i ).getType().get(); countLocal++; } } // at least two overlap if ( countLocal > 1 ) { sum.add( sumLocal ); count += countLocal; } } if ( count == 0 ) return 1; // compute the average sum final double avg = sum.getSum() / (double)count; /* for ( double d : sums ) avgSum += d; avgSum /= (double)(sums.length); // adjust so that the sum of each overlapping area equals the average sum for ( int i = 0; i < sums.length; ++i ) { sums[ i ] = avgSum / sums[ i ]; IJ.log( "Normalizing view " + (i+1) + " with " + sums[ i ] ); } // apply to data over the whole image for ( final Cursor<FloatType> c : cursorsImage ) c.reset(); while ( cursor.hasNext() ) { int i = 0; for ( final Cursor<FloatType> c : cursorsImage ) { c.fwd(); // TODO: this is removed for testing only! //c.getType().set( (float)(c.getType().get() / sums[ i++ ]) ); } } */ // close all cursors for ( final Cursor<FloatType> c : cursorsImage ) c.close(); for ( final Cursor<FloatType> c : cursorsWeight ) c.close(); // return the average intensity in the overlapping area return avg; } final private static BigDecimal sumImage( final Image<FloatType> img ) { BigDecimal sum = new BigDecimal( 0, MathContext.UNLIMITED ); final Cursor<FloatType> cursorImg = img.createCursor(); while ( cursorImg.hasNext() ) { cursorImg.fwd(); sum = sum.add( BigDecimal.valueOf( (double)cursorImg.getType().get() ) ); } cursorImg.close(); return sum; } final private static void normImage( final Image<FloatType> img ) { final BigDecimal sum = sumImage( img ); final Cursor<FloatType> cursor = img.createCursor(); while ( cursor.hasNext() ) { cursor.fwd(); cursor.getType().set( (float) ((double)cursor.getType().get() / sum.doubleValue()) ); } cursor.close(); } }