/*- * #%L * Fiji distribution of ImageJ for the life sciences. * %% * Copyright (C) 2007 - 2017 Fiji developers. * %% * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as * published by the Free Software Foundation, either version 2 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public * License along with this program. If not, see * <http://www.gnu.org/licenses/gpl-2.0.html>. * #L% */ package spim.process.fusion.deconvolution; import java.util.ArrayList; import java.util.List; import java.util.Vector; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import ij.ImageJ; import mpicbg.spim.io.IOFunctions; import net.imglib2.Cursor; import net.imglib2.RandomAccess; import net.imglib2.RandomAccessibleInterval; import net.imglib2.img.Img; import net.imglib2.img.ImgFactory; import net.imglib2.img.array.ArrayImgs; import net.imglib2.img.display.imagej.ImageJFunctions; import net.imglib2.type.numeric.real.FloatType; import net.imglib2.view.Views; import spim.Threads; import spim.process.fusion.FusionHelper; import spim.process.fusion.ImagePortion; import spim.process.fusion.weights.BlendingRealRandomAccess; import spim.process.fusion.weights.NormalizingRandomAccessibleInterval; public class WeightNormalizer { final List< RandomAccessibleInterval< FloatType > > weights; final Img< FloatType > sumWeights; int minOverlappingViews; double avgOverlappingViews; public WeightNormalizer( final List< RandomAccessibleInterval< FloatType > > weights ) { this.weights = weights; this.sumWeights = null; } public WeightNormalizer( final List< RandomAccessibleInterval< FloatType > > weights, final ImgFactory< FloatType > factory ) { this.weights = weights; this.sumWeights = factory.create( weights.get( 0 ), new FloatType() ); } public int getMinOverlappingViews() { return minOverlappingViews; } public double getAvgOverlappingViews() { return avgOverlappingViews; } public Img< FloatType > getSumWeights() { return sumWeights; } public boolean process() { // split up into many parts for multithreading final Vector< ImagePortion > portions = FusionHelper.divideIntoPortions( Views.iterable( weights.get( 0 ) ).size(), Threads.numThreads() * 2 ); // set up executor service final ExecutorService taskExecutor = Executors.newFixedThreadPool( Threads.numThreads() ); final ArrayList< Callable< double[] > > tasks = new ArrayList< Callable< double[] > >(); for ( final ImagePortion portion : portions ) { if ( sumWeights == null ) tasks.add( new ApplyDirectly( portion ) ); else tasks.add( new ComputeSumImage( portion, sumWeights ) ); } // run threads try { // invokeAll() returns when all tasks are complete final List< Future< double[] > > futures = taskExecutor.invokeAll( tasks ); this.minOverlappingViews = weights.size(); this.avgOverlappingViews = 0; for ( final Future< double[] > f : futures ) { final double[] minAvg = f.get(); this.minOverlappingViews = Math.min( this.minOverlappingViews, (int)Math.round( minAvg[ 0 ] ) ); this.avgOverlappingViews += minAvg[ 1 ]; } this.avgOverlappingViews /= futures.size(); } catch ( final Exception e ) { IOFunctions.println( "Failed to compute weight normalization for deconvolution: " + e ); e.printStackTrace(); return false; } taskExecutor.shutdown(); // set the normalizing interval if ( sumWeights != null ) { for ( int i = 0; i < weights.size(); ++i ) { final RandomAccessibleInterval< FloatType > w = weights.get( i ); final NormalizingRandomAccessibleInterval< FloatType > nw = new NormalizingRandomAccessibleInterval< FloatType >( w, sumWeights, new FloatType() ); weights.set( i, nw ); } } return true; } final private static void apply( final ArrayList< Cursor< FloatType > > cursors, final double sumW ) { for ( final Cursor< FloatType > c : cursors ) c.get().set( (float)( c.get().get() / sumW ) ); } final private class ApplyDirectly implements Callable< double[] > { final ImagePortion portion; public ApplyDirectly( final ImagePortion portion ) { this.portion = portion; } @Override public double[] call() throws Exception { final ArrayList< Cursor< FloatType > > cursors = new ArrayList< Cursor< FloatType > >(); for ( final RandomAccessibleInterval< FloatType > imgW : weights ) { final Cursor< FloatType > c = Views.iterable( imgW ).cursor(); c.jumpFwd( portion.getStartPosition() ); cursors.add( c ); } int minNumViews = cursors.size(); long countViews = 0; for ( long j = 0; j < portion.getLoopSize(); ++j ) { double sumW = 0; int count = 0; for ( final Cursor< FloatType > c : cursors ) { final float w = c.next().get(); sumW += w; if ( w > 0 ) ++count; } countViews += count; minNumViews = Math.min( minNumViews, count ); // something in between ... I would say, now we have hard edges where the image stacks end //if ( sumW > 1 ) apply( cursors, sumW ); } final double avgNumViews = (double)countViews / (double)( portion.getLoopSize() ); return new double[]{ minNumViews, avgNumViews }; } } final private class ComputeSumImage implements Callable< double[] > { final ImagePortion portion; final Img< FloatType > sumWeights; public ComputeSumImage( final ImagePortion portion, final Img< FloatType > sumWeights ) { this.portion = portion; this.sumWeights = sumWeights; } @Override public double[] call() throws Exception { final ArrayList< Cursor< FloatType > > cursors = new ArrayList< Cursor< FloatType > >(); final RandomAccess< FloatType > ra = sumWeights.randomAccess(); for ( int i = 0; i < weights.size(); ++i ) { final RandomAccessibleInterval< FloatType > imgW = weights.get( i ); final Cursor< FloatType > c; if ( i == 0 ) c = Views.iterable( imgW ).localizingCursor(); else c = Views.iterable( imgW ).cursor(); c.jumpFwd( portion.getStartPosition() ); cursors.add( c ); } // the first one is a localizablecursor final Cursor< FloatType > firstCursor = cursors.get( 0 ); int minNumViews = cursors.size(); long countViews = 0; for ( long j = 0; j < portion.getLoopSize(); ++j ) { double sumW = 0; int count = 0; for ( final Cursor< FloatType > c : cursors ) { final float w = c.next().get(); sumW += w; if ( w > 0 ) ++count; } countViews += count; minNumViews = Math.min( minNumViews, count ); ra.setPosition( firstCursor ); if ( sumW > 1 ) ra.get().set( (float)sumW ); else ra.get().setOne(); } final double avgNumViews = (double)countViews / (double)( portion.getLoopSize() ); return new double[]{ minNumViews, avgNumViews }; } } public static void main( String[] args ) { new ImageJ(); Img< FloatType > img = ArrayImgs.floats( 500, 500 ); BlendingRealRandomAccess blend = new BlendingRealRandomAccess( img, new float[]{ 100, 0 }, new float[]{ 12, 150 } ); Cursor< FloatType > c = img.localizingCursor(); while ( c.hasNext() ) { c.fwd(); blend.setPosition( c ); c.get().setReal( blend.get().getRealFloat() ); } ImageJFunctions.show( img ); } }