/*-
* #%L
* Fiji distribution of ImageJ for the life sciences.
* %%
* Copyright (C) 2007 - 2017 Fiji developers.
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 2 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program. If not, see
* <http://www.gnu.org/licenses/gpl-2.0.html>.
* #L%
*/
package spim.process.cuda;
import java.util.ArrayList;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import mpicbg.spim.io.IOFunctions;
import net.imglib2.Cursor;
import net.imglib2.FinalInterval;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.img.basictypeaccess.array.FloatArray;
import net.imglib2.img.display.imagej.ImageJFunctions;
import net.imglib2.iterator.LocalizingZeroMinIntervalIterator;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.view.Views;
import spim.Threads;
import spim.process.fusion.FusionHelper;
import spim.process.fusion.ImagePortion;
public class Block
{
/**
* the number of dimensions of this block
*/
final int numDimensions;
/**
* The dimensions of the block
*/
final long[] blockSize;
/**
* The offset in coordinates (coordinate system of the original image)
*/
final long[] offset;
/**
* The effective size that can be convolved (depends on the kernelsize)
*/
final long[] effectiveSize;
/**
* The effective offset, i.e. where the useful convolved data starts (coordinate system of the original image)
*/
final long[] effectiveOffset;
/**
* The effective offset, i.e. where the useful convoved data starts (local coordinate system)
*/
final long[] effectiveLocalOffset;
/**
* If the blocks that cover the image are precise or an approximation
*/
final boolean isPrecise;
final Vector< ImagePortion > portions;
final ExecutorService taskExecutor;
public Block(
final long[] blockSize,
final long[] offset,
final long[] effectiveSize,
final long[] effectiveOffset,
final long[] effectiveLocalOffset,
final boolean isPrecise )
{
this.numDimensions = blockSize.length;
this.blockSize = blockSize.clone();
this.offset = offset.clone();
this.effectiveSize = effectiveSize.clone();
this.effectiveOffset = effectiveOffset.clone();
this.effectiveLocalOffset = effectiveLocalOffset.clone();
this.isPrecise = isPrecise;
long n = blockSize[ 0 ];
for ( int d = 1; d < numDimensions; ++d )
n *= blockSize[ d ];
// split up into many parts for multithreading
this.portions = FusionHelper.divideIntoPortions( n, Threads.numThreads() * 2 );
this.taskExecutor = Executors.newFixedThreadPool( Threads.numThreads() );
}
public long[] getBlockSize()
{
final long[] dim = new long[ blockSize.length ];
for ( int d = 0; d < dim.length; ++d )
dim[ d ] = blockSize[ d ];
return dim;
}
@Override
public void finalize() { taskExecutor.shutdown(); }
/**
* @return - if the blocks that cover an area/volume/... are precise, i.e. if they are identical to performing the convolution on the entire image. Non-precise blocks do not need an outofbounds, they will not query data from outside of the blocked area.
*/
public boolean isPrecise() { return isPrecise; }
/**
* @param source - needs to be extended with an OutOfBounds in case the block extends past the boundaries of the RandomAccessibleInterval
* @param block - the Block to copy it to
*/
public void copyBlock( final RandomAccessible< FloatType > source, final RandomAccessibleInterval< FloatType > block )
{
// set up threads
final ArrayList< Callable< Boolean > > tasks = new ArrayList< Callable< Boolean > >();
for ( int i = 0; i < portions.size(); ++i )
{
final int threadIdx = i;
tasks.add( new Callable< Boolean >()
{
@SuppressWarnings("unchecked")
@Override
public Boolean call() throws Exception
{
if ( source.numDimensions() == 3 && ArrayImg.class.isInstance( block ) )
copy3dArray( threadIdx, portions.size(), source, (ArrayImg< FloatType, ?>)block, offset );
else
{
final ImagePortion portion = portions.get( threadIdx );
copy( portion.getStartPosition(), portion.getLoopSize(), source, block, offset);
}
return true;
}
});
}
try
{
// invokeAll() returns when all tasks are complete
taskExecutor.invokeAll( tasks );
}
catch ( final InterruptedException e )
{
IOFunctions.println( "Failed to copy block: " + e );
e.printStackTrace();
return;
}
}
public void pasteBlock( final RandomAccessibleInterval< FloatType > target, final RandomAccessibleInterval< FloatType > block )
{
// set up threads
final ArrayList< Callable< Boolean > > tasks = new ArrayList< Callable< Boolean > >();
for ( int i = 0; i < portions.size(); ++i )
{
final int threadIdx = i;
tasks.add( new Callable< Boolean >()
{
@SuppressWarnings("unchecked")
@Override
public Boolean call() throws Exception
{
if ( target.numDimensions() == 3 && ArrayImg.class.isInstance( target ) && ArrayImg.class.isInstance( block ) )
paste3d( threadIdx, portions.size(), (ArrayImg< FloatType, ?>)target, (ArrayImg< FloatType, ?>)block, effectiveOffset, effectiveSize, effectiveLocalOffset );
else
{
final ImagePortion portion = portions.get( threadIdx );
paste( portion.getStartPosition(), portion.getLoopSize(), target, block, effectiveOffset, effectiveSize, effectiveLocalOffset );
}
return true;
}
});
}
try
{
// invokeAll() returns when all tasks are complete
taskExecutor.invokeAll( tasks );
}
catch ( final InterruptedException e )
{
IOFunctions.println( "Failed to paste block: " + e );
e.printStackTrace();
return;
}
}
private static final void copy( final long start, final long loopSize, final RandomAccessible< FloatType > source, final RandomAccessibleInterval< FloatType > block, final long[] offset )
{
final int numDimensions = source.numDimensions();
final Cursor< FloatType > cursor = Views.iterable( block ).localizingCursor();
// define where we will query the RandomAccess on the source
// (we say it is the entire block, although it is just a part of it,
// but which part depends on the underlying container)
final long[] min = new long[ numDimensions ];
final long[] max = new long[ numDimensions ];
for ( int d = 0; d < numDimensions; ++d )
{
min[ d ] = offset[ d ];
max[ d ] = offset[ d ] + block.dimension( d ) - 1;
}
final RandomAccess< FloatType > randomAccess = source.randomAccess( new FinalInterval( min, max ) );
cursor.jumpFwd( start );
final long[] tmp = new long[ numDimensions ];
for ( long l = 0; l < loopSize; ++l )
{
cursor.fwd();
cursor.localize( tmp );
for ( int d = 0; d < numDimensions; ++d )
tmp[ d ] += offset[ d ];
randomAccess.setPosition( tmp );
cursor.get().set( randomAccess.get() );
}
}
private static final void copy3dArray( final int threadIdx, final int numThreads, final RandomAccessible< FloatType > source, final ArrayImg< FloatType, ? > block, final long[] offset )
{
final int w = (int)block.dimension( 0 );
final int h = (int)block.dimension( 1 );
final int d = (int)block.dimension( 2 );
final long offsetX = offset[ 0 ];
final long offsetY = offset[ 1 ];
final long offsetZ = offset[ 2 ];
final float[] blockArray = ((FloatArray)block.update( null ) ).getCurrentStorageArray();
// define where we will query the RandomAccess on the source
final FinalInterval interval = new FinalInterval( new long[] { offsetX, offsetY, offsetZ }, new long[] { offsetX + w - 1, offsetY + h - 1, offsetZ + d - 1 } );
final RandomAccess< FloatType > randomAccess = source.randomAccess( interval );
final long[] tmp = new long[]{ offsetX, offsetY, 0 };
for ( int z = threadIdx; z < d; z += numThreads )
{
tmp[ 2 ] = z + offsetZ;
randomAccess.setPosition( tmp );
int i = z * h * w;
for ( int y = 0; y < h; ++y )
{
randomAccess.setPosition( offsetX, 0 );
for ( int x = 0; x < w; ++x )
{
blockArray[ i++ ] = randomAccess.get().get();
randomAccess.fwd( 0 );
}
randomAccess.move( -w, 0 );
randomAccess.fwd( 1 );
}
}
}
private static final void paste( final long start, final long loopSize, final RandomAccessibleInterval< FloatType > target, final RandomAccessibleInterval< FloatType > block,
final long[] effectiveOffset, final long[] effectiveSize, final long[] effectiveLocalOffset )
{
final int numDimensions = target.numDimensions();
// iterate over effective size
final LocalizingZeroMinIntervalIterator cursor = new LocalizingZeroMinIntervalIterator( effectiveSize );
// read from block
final RandomAccess<FloatType> blockRandomAccess = block.randomAccess();
// write to target
final RandomAccess<FloatType> targetRandomAccess = target.randomAccess();
cursor.jumpFwd( start );
final long[] tmp = new long[ numDimensions ];
for ( long l = 0; l < loopSize; ++l )
{
cursor.fwd();
cursor.localize( tmp );
// move to the relative local offset where the real data starts
for ( int d = 0; d < numDimensions; ++d )
tmp[ d ] += effectiveLocalOffset[ d ];
blockRandomAccess.setPosition( tmp );
// move to the right position in the image
for ( int d = 0; d < numDimensions; ++d )
tmp[ d ] += effectiveOffset[ d ] - effectiveLocalOffset[ d ];
targetRandomAccess.setPosition( tmp );
// write the pixel
targetRandomAccess.get().set( blockRandomAccess.get() );
}
}
private static final void paste3d( final int threadIdx, final int numThreads, final ArrayImg< FloatType, ? > target, final ArrayImg< FloatType, ? > block,
final long[] effectiveOffset, final long[] effectiveSize, final long[] effectiveLocalOffset )
{
// min position in the output
final int minX = (int)effectiveOffset[ 0 ];
final int minY = (int)effectiveOffset[ 1 ];
final int minZ = (int)effectiveOffset[ 2 ];
// max+1 of the output area
final int maxY = (int)effectiveSize[ 1 ] + minY;
final int maxZ = (int)effectiveSize[ 2 ] + minZ;
// size of the output area
final int sX = (int)effectiveSize[ 0 ];
// min position in the output
final int minXb = (int)effectiveLocalOffset[ 0 ];
final int minYb = (int)effectiveLocalOffset[ 1 ];
final int minZb = (int)effectiveLocalOffset[ 2 ];
// size of the target image
final int w = (int)target.dimension( 0 );
final int h = (int)target.dimension( 1 );
// size of the block image
final int wb = (int)block.dimension( 0 );
final int hb = (int)block.dimension( 1 );
final float[] blockArray = ((FloatArray)block.update( null ) ).getCurrentStorageArray();
final float[] targetArray = ((FloatArray)target.update( null ) ).getCurrentStorageArray();
for ( int z = minZ + threadIdx; z < maxZ; z += numThreads )
{
final int zBlock = z - minZ + minZb;
int iTarget = z * h * w + minY * w + minX;
int iBlock = zBlock * hb * wb + minYb * wb + minXb;
for ( int y = minY; y < maxY; ++y )
{
copyX( blockArray, targetArray, sX, iTarget, iBlock );
iTarget += w;
iBlock += wb;
}
}
}
private static final void copyX( final float[] blockArray, final float[] targetArray, final int count, int iTarget, int iBlock )
{
for ( int x = 0; x < count; ++x )
targetArray[ iTarget++ ] = blockArray[ iBlock++ ];
}
public static void main( String[] args )
{
// define the blocksize so that it is one single block
final RandomAccessibleInterval< FloatType > block = ArrayImgs.floats( 384, 384 );
final long[] blockSize = new long[ block.numDimensions() ];
block.dimensions( blockSize );
final RandomAccessibleInterval< FloatType > image = ArrayImgs.floats( 1024, 1024 );
final long[] imgSize = new long[ image.numDimensions() ];
image.dimensions( imgSize );
// whatever the kernel size is (extra size/2 in general)
final long[] kernelSize = new long[]{ 16, 32 };
final BlockGeneratorFixedSizePrecise blockGenerator = new BlockGeneratorFixedSizePrecise( blockSize );
final Block[] blocks = blockGenerator.divideIntoBlocks( imgSize, kernelSize );
int i = 0;
for ( final Block b : blocks )
{
// copy data from the image to the block (including extra space for outofbounds/real image data depending on kernel size)
b.copyBlock( Views.extendMirrorDouble( image ), block );
// do something with the block (e.g. also multithreaded, cluster, ...)
for ( final FloatType f : Views.iterable( block ) )
f.set( i );
++i;
// write the block back (use a temporary image if multithreaded or in general not all are copied first)
b.pasteBlock( image, block );
}
ImageJFunctions.show( image );
}
}