/*-
* #%L
* Fiji distribution of ImageJ for the life sciences.
* %%
* Copyright (C) 2007 - 2017 Fiji developers.
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 2 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program. If not, see
* <http://www.gnu.org/licenses/gpl-2.0.html>.
* #L%
*/
package mpicbg.spim.segmentation;
import ij.ImageJ;
import java.util.concurrent.atomic.AtomicInteger;
import mpicbg.imglib.algorithm.integral.IntegralImageLong;
import mpicbg.imglib.container.array.Array;
import mpicbg.imglib.container.array.ArrayContainerFactory;
import mpicbg.imglib.container.basictypecontainer.FloatAccess;
import mpicbg.imglib.container.basictypecontainer.LongAccess;
import mpicbg.imglib.container.basictypecontainer.array.FloatArray;
import mpicbg.imglib.container.basictypecontainer.array.LongArray;
import mpicbg.imglib.cursor.LocalizableByDimCursor;
import mpicbg.imglib.function.Converter;
import mpicbg.imglib.image.Image;
import mpicbg.imglib.image.ImageFactory;
import mpicbg.imglib.io.LOCI;
import mpicbg.imglib.multithreading.SimpleMultiThreading;
import mpicbg.imglib.type.numeric.integer.LongType;
import mpicbg.imglib.type.numeric.real.FloatType;
import mpicbg.imglib.util.Util;
public class IntegralImage3d
{
final public static int getIndex( final int x, final int y, final int z, final int w, final int h )
{
return (x + w * (y + z * h));
}
final public static Image< LongType > compute( final Image< FloatType > img )
{
final Image< LongType > integralTmp;
if ( img.getContainer() instanceof Array )
{
final ImageFactory< LongType > imgFactory = new ImageFactory< LongType >( new LongType(), new ArrayContainerFactory() );
integralTmp = imgFactory.createImage( new int[]{ img.getDimension( 0 ) + 1, img.getDimension( 1 ) + 1, img.getDimension( 2 ) + 1 } );
computeArray( integralTmp, img );
}
else
{
final ImageFactory< LongType > imgFactory = new ImageFactory< LongType >( new LongType(), img.getContainerFactory() );
integralTmp = imgFactory.createImage( new int[]{ img.getDimension( 0 ) + 1, img.getDimension( 1 ) + 1, img.getDimension( 2 ) + 1 } );
compute( integralTmp, img );
}
return integralTmp;
}
final public static void computeIntegralImage( final Image< LongType > integralTmp, final Image< FloatType > img )
{
if ( ( integralTmp.getContainer() instanceof Array ) && ( img.getContainer() instanceof Array ) )
computeArray( integralTmp, img );
else
compute( integralTmp, img );
}
final private static void computeArray( final Image< LongType > integralTmp, final Image< FloatType > img )
{
final Array<LongType, LongAccess> array1 = (Array<LongType, LongAccess>)integralTmp.getContainer();
final LongArray longarray = (LongArray)array1.update( null );
final long[] data = longarray.getCurrentStorageArray();
final Array<FloatType, FloatAccess> array2 = (Array<FloatType, FloatAccess>)img.getContainer();
final FloatArray floatarray = (FloatArray)array2.update( null );
final float[] dataF = floatarray.getCurrentStorageArray();
final int w = integralTmp.getDimension( 0 );
final int h = integralTmp.getDimension( 1 );
final int d = integralTmp.getDimension( 2 );
final int wf = img.getDimension( 0 );
final int hf = img.getDimension( 1 );
final AtomicInteger ai = new AtomicInteger(0);
final Thread[] threads = SimpleMultiThreading.newThreads( );
final int numThreads = threads.length;
//
// sum over x
//
for (int ithread = 0; ithread < threads.length; ++ithread)
threads[ithread] = new Thread(new Runnable()
{
public void run()
{
// Thread ID
final int myNumber = ai.getAndIncrement();
for ( int z = 1; z < d; ++z )
{
if ( z % numThreads == myNumber )
{
for ( int y = 1; y < h; ++y )
{
int indexIn = getIndex( 0, y - 1, z - 1, wf, hf );
int indexOut = getIndex( 1, y, z, w, h );
// compute the first pixel
long sum = (int)( dataF[ indexIn ] );
data[ indexOut ] = sum;
for ( int x = 2; x < w; ++x )
{
++indexIn;
++indexOut;
sum += (int)( dataF[ indexIn ] );
data[ indexOut ] = sum;
}
}
}
}
}
});
SimpleMultiThreading.startAndJoin( threads );
//
// sum over y
//
ai.set( 0 );
for (int ithread = 0; ithread < threads.length; ++ithread)
threads[ithread] = new Thread(new Runnable()
{
public void run()
{
// Thread ID
final int myNumber = ai.getAndIncrement();
//int index = 0;
for ( int z = 1; z < d; ++z )
{
if ( z % numThreads == myNumber )
{
for ( int x = 1; x < w; ++x )
{
int index = getIndex( x, 1, z, w, h );
// init sum on first pixel that is not zero
long sum = data[ index ];
for ( int y = 2; y < h; ++y )
{
index += w;
sum += data[ index ];
data[ index ] = sum;
}
}
}
}
}
});
SimpleMultiThreading.startAndJoin( threads );
//
// sum over z
//
ai.set( 0 );
for (int ithread = 0; ithread < threads.length; ++ithread)
threads[ithread] = new Thread(new Runnable()
{
public void run()
{
// Thread ID
final int myNumber = ai.getAndIncrement();
//int index = 0;
final int inc = w*h;
for ( int y = 1; y < h; ++y )
{
if ( y % numThreads == myNumber )
{
for ( int x = 1; x < w; ++x )
{
int index = getIndex( x, y, 1, w, h );
//System.out.println( index + " " + data[ index ] );
// init sum on first pixel that is not zero
long sum = data[ index ];
for ( int z = 2; z < d; ++z )
{
index += inc;
//System.out.println( index + " " + data[ index ] );
sum += data[ index ];
data[ index ] = sum;
}
//System.out.println();
}
}
}
}
});
SimpleMultiThreading.startAndJoin( threads );
}
final private static void compute( final Image< LongType > integralTmp, final Image< FloatType > img )
{
final int w = integralTmp.getDimension( 0 );
final int h = integralTmp.getDimension( 1 );
final int d = integralTmp.getDimension( 2 );
final int wf = img.getDimension( 0 );
final int hf = img.getDimension( 1 );
final AtomicInteger ai = new AtomicInteger(0);
final Thread[] threads = SimpleMultiThreading.newThreads( );
final int numThreads = threads.length;
//
// sum over x
//
for (int ithread = 0; ithread < threads.length; ++ithread)
threads[ithread] = new Thread(new Runnable()
{
public void run()
{
// Thread ID
final int myNumber = ai.getAndIncrement();
final LocalizableByDimCursor< LongType > data = integralTmp.createLocalizableByDimCursor();
final LocalizableByDimCursor< FloatType > dataF = img.createLocalizableByDimCursor();
final int[] pos = new int[ 3 ];
final int[] posF = new int[ 3 ];
for ( int z = 1; z < d; ++z )
{
if ( z % numThreads == myNumber )
{
posF[ 2 ] = z - 1;
pos[ 2 ] = z;
for ( int y = 1; y < h; ++y )
{
posF[ 1 ] = y-1;
posF[ 0 ] = 0;
pos[ 1 ] = y;
pos[ 0 ] = 1;
dataF.setPosition( posF );
data.setPosition( pos );
//int indexIn = getIndex( 0, y - 1, z - 1, wf, hf );
//int indexOut = getIndex( 1, y, z, w, h );
// compute the first pixel
long sum = (int)( dataF.getType().get() );
data.getType().set( sum );
for ( int x = 2; x < w; ++x )
{
data.fwd( 0 );
dataF.fwd( 0 );
sum += (int)( dataF.getType().get() );
data.getType().set( sum );
}
}
}
}
data.close();
dataF.close();
}
});
SimpleMultiThreading.startAndJoin( threads );
//
// sum over y
//
ai.set( 0 );
for (int ithread = 0; ithread < threads.length; ++ithread)
threads[ithread] = new Thread(new Runnable()
{
public void run()
{
// Thread ID
final int myNumber = ai.getAndIncrement();
//int index = 0;
final LocalizableByDimCursor< LongType > data = integralTmp.createLocalizableByDimCursor();
final int[] pos = new int[ 3 ];
for ( int z = 1; z < d; ++z )
{
if ( z % numThreads == myNumber )
{
pos[ 2 ] = z;
for ( int x = 1; x < w; ++x )
{
pos[ 0 ] = x;
pos[ 1 ] = 1;
data.setPosition( pos );
//int index = getIndex( x, 1, z, w, h );
// init sum on first pixel that is not zero
long sum = data.getType().get();
for ( int y = 2; y < h; ++y )
{
data.fwd( 1 );
sum += data.getType().get();
data.getType().set( sum );
}
}
}
}
data.close();
}
});
SimpleMultiThreading.startAndJoin( threads );
//
// sum over z
//
ai.set( 0 );
for (int ithread = 0; ithread < threads.length; ++ithread)
threads[ithread] = new Thread(new Runnable()
{
public void run()
{
// Thread ID
final int myNumber = ai.getAndIncrement();
//int index = 0;
final int inc = w*h;
final LocalizableByDimCursor< LongType > data = integralTmp.createLocalizableByDimCursor();
final int[] pos = new int[ 3 ];
for ( int y = 1; y < h; ++y )
{
if ( y % numThreads == myNumber )
{
pos[ 1 ] = y;
for ( int x = 1; x < w; ++x )
{
pos[ 0 ] = x;
pos[ 2 ] = 1;
data.setPosition( pos );
//int index = getIndex( x, y, 1, w, h );
// init sum on first pixel that is not zero
long sum = data.getType().get();
for ( int z = 2; z < d; ++z )
{
//index += inc;
data.fwd( 2 );
//System.out.println( index + " " + data[ index ] );
sum += data.getType().get();//[ index ];
data.getType().set( sum );
}
//System.out.println();
}
}
}
}
});
SimpleMultiThreading.startAndJoin( threads );
}
public static void main( String[] args )
{
new ImageJ();
//Image< FloatType > img = new ImageFactory< FloatType >( new FloatType(), new ArrayContainerFactory() ).createImage( new int[]{ 2, 3, 4 } );
Image< FloatType > img = LOCI.openLOCIFloatType( "/Users/preibischs/Documents/Microscopy/SPIM/HisYFP-SPIM/spim_TL18_Angle0.tif", new ArrayContainerFactory() );
//int i = 1;
//for ( final FloatType t : img )
/// t.set( i++ );
long t = 0;
System.out.println( "new implementation" );
for ( int i = 0; i < 10; ++i )
{
long t1 = System.currentTimeMillis();
final Image< LongType > integralImg = compute( img );
integralImg.close();
long t2 = System.currentTimeMillis();
System.out.println( (t2 - t1) + " ms" );
t += t2-t1;
}
System.out.println( "avg: " + (t/10) + " ms" );
System.out.println( "\nold implementation" );
t = 0;
for ( int i = 0; i < 10; ++i )
{
long t1 = System.currentTimeMillis();
final IntegralImageLong< FloatType > intImg = new IntegralImageLong<FloatType>( img, new Converter< FloatType, LongType >()
{
@Override
public void convert( final FloatType input, final LongType output ) { output.set( Util.round( input.get() ) ); }
} );
intImg.process();
final Image< LongType > integralImg = intImg.getResult();
integralImg.close();
long t2 = System.currentTimeMillis();
System.out.println( (t2 - t1) + " ms" );
t += t2-t1;
}
System.out.println( "avg: " + (t/10) + " ms" );
/*
final IntegralImageLong< FloatType > intImg = new IntegralImageLong<FloatType>( img, new Converter< FloatType, LongType >()
{
@Override
public void convert( final FloatType input, final LongType output ) { output.set( Util.round( input.get() ) ); }
} );
intImg.process();
final Image< LongType > integralImg = intImg.getResult();
final Image< LongType > integralImgNew = computeArray( img );
ImageJFunctions.show( img ).setTitle( "img" );
ImageJFunctions.show( integralImg ).setTitle( "integral_correct");
ImageJFunctions.show( integralImgNew ).setTitle( "integral_new");
final Image< LongType> diff = integralImg.createNewImage();
final Cursor< LongType > c1 = integralImg.createCursor();
final Cursor< LongType > c2 = integralImgNew.createCursor();
final Cursor< LongType > cd = diff.createCursor();
while ( c1.hasNext() )
{
cd.next().set( c1.next().get() - c2.next().get() );
}
ImageJFunctions.show( diff ).setTitle( "integral_diff");
*/
}
}