/**
 *
 */
package org.janelia.intensity;

import ij.IJ;
import ij.ImageJ;
import ij.ImagePlus;
import ij.ImageStack;
import ij.gui.GenericDialog;
import ij.gui.Roi;
import ij.measure.Calibration;
import ij.process.ColorProcessor;
import ij.process.FloatProcessor;
import ini.trakem2.Project;
import ini.trakem2.display.Display;
import ini.trakem2.display.Displayable;
import ini.trakem2.display.Layer;
import ini.trakem2.display.LayerSet;
import ini.trakem2.display.Patch;
import ini.trakem2.persistence.FSLoader;
import ini.trakem2.plugin.TPlugIn;
import ini.trakem2.utils.Utils;

import java.awt.Rectangle;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map.Entry;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import mpicbg.models.Affine1D;
import mpicbg.models.AffineModel1D;
import mpicbg.models.IdentityModel;
import mpicbg.models.IllDefinedDataPointsException;
import mpicbg.models.InterpolatedAffineModel1D;
import mpicbg.models.Model;
import mpicbg.models.NotEnoughDataPointsException;
import mpicbg.models.Point;
import mpicbg.models.PointMatch;
import mpicbg.models.Tile;
import mpicbg.models.TileConfiguration;
import mpicbg.models.TranslationModel1D;
import net.imglib2.img.list.ListImg;
import net.imglib2.img.list.ListRandomAccess;
import net.imglib2.util.ValuePair;

/**
 * TrakEM2 plugin that matches the intensities of overlapping {@link Patch}es.
 *
 * <p>
 * Every patch is covered by a grid of <em>numCoefficients</em> x
 * <em>numCoefficients</em> {@link Tile}s, each holding a 1D affine intensity
 * model. Overlapping patches are rendered into their common intersection,
 * the pixel pairs found there are turned into {@link PointMatch}es that
 * connect the corresponding tiles, neighboring tiles of the same patch are
 * connected by identity matches, and the whole {@link TileConfiguration} is
 * optimized. The resulting coefficients are saved as a two-slice TIFF per
 * patch and the mipmaps of all involved patches are regenerated.
 * </p>
 *
 * @author Stephan Saalfeld saalfelds@janelia.hhmi.org
 * @author Philipp Hanslovsky
 */
public class MatchIntensities implements TPlugIn
{
	/**
	 * Collects intensity matches for one pair of overlapping patches and
	 * connects the coefficient tiles of both patches accordingly.
	 *
	 * <p>
	 * This has to remain an inner (non-static) class because connecting tiles
	 * is synchronized on the enclosing {@link MatchIntensities} instance.
	 * </p>
	 */
	final private class Matcher implements Runnable
	{
		final private Rectangle roi;
		final private ValuePair< Patch, Patch > patchPair;
		final private HashMap< Patch, ArrayList< Tile< ? > > > coefficientsTiles;
		final private PointMatchFilter filter;
		final private double scale;
		final private int numCoefficients;

		/**
		 * @param roi region of interest that the first patch is clipped to
		 * @param patchPair the two patches to be matched
		 * @param coefficientsTiles coefficient tiles of all patches
		 * @param filter filter applied to the candidate matches of each pair
		 *            of coefficient tiles
		 * @param scale scale factor at which the intersection is rendered
		 * @param numCoefficients number of coefficient tiles per row and per
		 *            column of a patch
		 */
		public Matcher(
				final Rectangle roi,
				final ValuePair< Patch, Patch > patchPair,
				final HashMap< Patch, ArrayList< Tile< ? > > > coefficientsTiles,
				final PointMatchFilter filter,
				final double scale,
				final int numCoefficients )
		{
			this.roi = roi;
			this.patchPair = patchPair;
			this.coefficientsTiles = coefficientsTiles;
			this.filter = filter;
			this.scale = scale;
			this.numCoefficients = numCoefficients;
		}

		@Override
		public void run()
		{
			final Patch p1 = patchPair.getA();
			final Patch p2 = patchPair.getB();
			final Rectangle box1 = p1.getBoundingBox().intersection( roi );

			/* get the coefficient tiles */
			final ArrayList< Tile< ? > > p1CoefficientsTiles = coefficientsTiles.get( p1 );

			/* render intersection */
			final Rectangle box2 = p2.getBoundingBox();
			final Rectangle box = box1.intersection( box2 );

			final int w = ( int ) ( box.width * scale + 0.5 );
			final int h = ( int ) ( box.height * scale + 0.5 );

			/*
			 * an empty intersection (Rectangle.intersection may yield negative
			 * sizes) or one that vanishes at the given scale has no samples,
			 * so there is nothing to render or connect
			 */
			if ( w <= 0 || h <= 0 )
				return;

			final int n = w * h;

			final FloatProcessor pixels1 = new FloatProcessor( w, h );
			final FloatProcessor weights1 = new FloatProcessor( w, h );
			final ColorProcessor coefficients1 = new ColorProcessor( w, h );
			final FloatProcessor pixels2 = new FloatProcessor( w, h );
			final FloatProcessor weights2 = new FloatProcessor( w, h );
			final ColorProcessor coefficients2 = new ColorProcessor( w, h );

			Render.render( p1, numCoefficients, numCoefficients, pixels1, weights1, coefficients1, box.x, box.y, scale );
			Render.render( p2, numCoefficients, numCoefficients, pixels2, weights2, coefficients2, box.x, box.y, scale );

			/*
			 * generate a matrix of all coefficients in p1 to all
			 * coefficients in p2 to store matches
			 */
			final int numTiles = numCoefficients * numCoefficients;
			final ArrayList< ArrayList< PointMatch > > list = new ArrayList< ArrayList< PointMatch > >();
			for ( int i = 0; i < numTiles * numTiles; ++i )
				list.add( new ArrayList< PointMatch >() );

			final ListImg< ArrayList< PointMatch > > matrix = new ListImg< ArrayList< PointMatch > >( list, numTiles, numTiles );
			final ListRandomAccess< ArrayList< PointMatch > > ra = matrix.randomAccess();

			/*
			 * iterate over all pixels and feed matches into the match
			 * matrix
			 */
			for ( int i = 0; i < n; ++i )
			{
				/* coefficient labels <= 0 and weights <= 0 are skipped */
				final int c1 = coefficients1.get( i );
				if ( c1 <= 0 )
					continue;

				final int c2 = coefficients2.get( i );
				if ( c2 <= 0 )
					continue;

				final double w1 = weights1.getf( i );
				if ( w1 <= 0 )
					continue;

				final double w2 = weights2.getf( i );
				if ( w2 <= 0 )
					continue;

				final double p = pixels1.getf( i );
				final double q = pixels2.getf( i );
				final PointMatch pq = new PointMatch(
						new Point( new double[] { p } ),
						new Point( new double[] { q } ),
						w1 * w2 );

				/* first label is 1 */
				ra.setPosition( c1 - 1, 0 );
				ra.setPosition( c2 - 1, 1 );
				ra.get().add( pq );
			}

			/* filter matches */
			final ArrayList< PointMatch > inliers = new ArrayList< PointMatch >();
			for ( final ArrayList< PointMatch > candidates : matrix )
			{
				inliers.clear();
				filter.filter( candidates, inliers );
				candidates.clear();
				candidates.addAll( inliers );
			}

			/* get the coefficient tiles of p2 */
			final ArrayList< Tile< ? > > p2CoefficientsTiles = coefficientsTiles.get( p2 );

			/* connect tiles across patches */
			for ( int i = 0; i < numTiles; ++i )
			{
				final Tile< ? > t1 = p1CoefficientsTiles.get( i );
				ra.setPosition( i, 0 );
				for ( int j = 0; j < numTiles; ++j )
				{
					ra.setPosition( j, 1 );
					final ArrayList< PointMatch > matches = ra.get();
					if ( matches.size() > 0 )
					{
						final Tile< ? > t2 = p2CoefficientsTiles.get( j );

						/* tiles are shared between concurrently running matchers */
						synchronized ( MatchIntensities.this )
						{
							t1.connect( t2, matches );
							IJ.log( "Connected patch " + p1.getId() + ", coefficient " + i + " + patch " + p2.getId() + ", coefficient " + j + " by " + matches.size() + " samples." );
						}
					}
				}
			}
		}
	}

	/** the layer set to work on, assigned by {@link #setup(Object...)} */
	protected LayerSet layerset = null;

	/* dialog values, remembered across invocations */
	static protected int numCoefficients = 8;
	static protected double lambda1 = 0.01;
	static protected double lambda2 = 0.01;
	static protected double neighborWeight = 0.1;
	static protected int radius = 5;
	static protected int iterations = 2000;

	/** values <= 0 make the dialog offer {@link #suggestScale(List)} instead */
	static protected double scale = -1;

	/**
	 * Find the layer to initialize the layer range choices with.
	 *
	 * @param params optional plugin parameters; only the first one is
	 *            inspected and may be a {@link Layer}, a {@link LayerSet} or a
	 *            {@link Displayable}
	 * @return the layer derived from the first parameter, {@code null} if
	 *         that parameter is of an unsupported type; without a parameter
	 *         the layer of the front {@link Display} or, if there is no
	 *         display, the first layer of the first project
	 */
	private Layer currentLayer( final Object... params )
	{
		/* a varargs call without arguments passes an empty array */
		if ( params != null && params.length > 0 && params[ 0 ] != null )
		{
			final Object param = params[ 0 ];
			if ( param instanceof Layer )
				return ( Layer ) param;
			if ( param instanceof LayerSet )
				return ( ( LayerSet ) param ).getLayer( 0 );
			if ( param instanceof Displayable )
				return ( ( Displayable ) param ).getLayer();
			return null;
		}

		final Display front = Display.getFront();
		if ( front == null )
			return Project.getProjects().get( 0 ).getRootLayerSet().getLayer( 0 );
		return front.getLayer();
	}

	/**
	 * Get the region of interest to work on.
	 *
	 * @param layerset used for the fallback region
	 * @return the bounds of the ROI of the front {@link Display} or, if there
	 *         is no display or no ROI, a rectangle at the origin with the
	 *         layer width and height of {@code layerset}
	 */
	private static Rectangle getRoi( final LayerSet layerset )
	{
		final Display front = Display.getFront();
		final Roi roi = front == null ? null : front.getRoi();
		if ( roi == null )
			return new Rectangle( 0, 0, ( int ) layerset.getLayerWidth(), ( int ) layerset.getLayerHeight() );
		return roi.getBounds();
	}

	/**
	 * Suggest a rendering scale from the z-distance of the first two layers
	 * and the calibration of their layer set.
	 *
	 * @param layers the layers to inspect
	 * @return {@code ( z2 - z1 ) * ( ( pixelWidth + pixelHeight ) / 2 ) / pixelDepth},
	 *         or 0 if there are fewer than two layers
	 */
	private double suggestScale( final List< Layer > layers )
	{
		if ( layers.size() < 2 )
			return 0.0;

		final Layer layer1 = layers.get( 0 );
		final Layer layer2 = layers.get( 1 );
		final Calibration calib = layer1.getParent().getCalibration();
		final double width = ( calib.pixelWidth + calib.pixelHeight ) * 0.5;
		final double depth = calib.pixelDepth;
		return ( layer2.getZ() - layer1.getZ() ) * width / depth;
	}

	/**
	 * Assign {@link #layerset} from the first parameter or, without a
	 * parameter, from the front {@link Display} or the first project.
	 *
	 * @param params optional plugin parameters; only the first one is
	 *            inspected and may be a {@link LayerSet} or a
	 *            {@link Displayable}
	 * @return {@code false} if the first parameter is of an unsupported type,
	 *         {@code true} otherwise
	 */
	@Override
	public boolean setup( final Object... params )
	{
		/* a varargs call without arguments passes an empty array */
		if ( params != null && params.length > 0 && params[ 0 ] != null )
		{
			final Object param = params[ 0 ];
			if ( param instanceof LayerSet )
				layerset = ( LayerSet ) param;
			else if ( param instanceof Displayable )
				layerset = ( ( Displayable ) param ).getLayerSet();
			else
				return false;
		}
		else
		{
			final Display front = Display.getFront();
			if ( front == null )
				layerset = Project.getProjects().get( 0 ).getRootLayerSet();
			else
				layerset = front.getLayerSet();
		}
		return true;
	}

	/**
	 * Create the coefficient tiles for a collection of patches.
	 *
	 * @param patches the patches to create tiles for
	 * @param template model that is copied into each tile
	 * @param nCoefficients number of tiles to create per patch
	 * @return a map from each patch to its list of {@code nCoefficients}
	 *         tiles
	 */
	final static protected < T extends Model< T > & Affine1D< T > > HashMap< Patch, ArrayList< Tile< T > > > generateCoefficientsTiles(
			final Collection< Patch > patches,
			final T template,
			final int nCoefficients )
	{
		final HashMap< Patch, ArrayList< Tile< T > > > map = new HashMap< Patch, ArrayList< Tile< T > > >();
		for ( final Patch p : patches )
		{
			final ArrayList< Tile< T > > coefficientModels = new ArrayList< Tile< T > >();
			for ( int i = 0; i < nCoefficients; ++i )
				coefficientModels.add( new Tile< T >( template.copy() ) );

			map.put( p, coefficientModels );
		}
		return map;
	}

	/**
	 * Connect two tiles by the two matches 0 -> 0 and 1 -> 1, i.e. by
	 * matches that are satisfied exactly when both tiles carry the same 1D
	 * affine model.
	 *
	 * @param t1 first tile
	 * @param t2 second tile
	 * @param weight weight assigned to both matches
	 */
	final static protected void identityConnect( final Tile< ? > t1, final Tile< ? > t2, final double weight )
	{
		final ArrayList< PointMatch > matches = new ArrayList< PointMatch >();

		/*
		 * the weight was previously dropped, which made the smoothness
		 * regularization parameter ineffective
		 */
		matches.add( new PointMatch( new Point( new double[] { 0 } ), new Point( new double[] { 0 } ), weight ) );
		matches.add( new PointMatch( new Point( new double[] { 1 } ), new Point( new double[] { 1 } ), weight ) );
		t1.connect( t2, matches );
	}

	/**
	 * Show the parameter dialog and run the intensity matching on the
	 * selected layer range.
	 *
	 * @param params passed to {@link #setup(Object...)}
	 * @return always {@code null}
	 */
	@Override
	public Object invoke( final Object... params )
	{
		if ( !setup( params ) )
			return null;

		final Layer layer = currentLayer( params );
		final GenericDialog gd = new GenericDialog( "Match intensities" );
		Utils.addLayerRangeChoices( layer, gd );
		gd.addMessage( "Layer range :" );
		gd.addNumericField( "scale : ", scale > 0 ? scale : suggestScale( layerset.getLayers() ), 3, 6, "" );
		gd.addNumericField( "coefficient resolution : ", numCoefficients, 0, 6, "" );
		gd.addNumericField( "test_maximally :", radius, 0, 6, "layers" );
		gd.addMessage( "Optimizer :" );
		gd.addNumericField( "iterations :", iterations, 0, 6, "" );
		gd.addNumericField( "scale_regularization :", lambda1, 2, 6, "" );
		gd.addNumericField( "translation_regularization :", lambda2, 2, 6, "" );
		gd.addNumericField( "smoothness_regularization :", neighborWeight, 2, 6, "" );
		gd.showDialog();

		if ( gd.wasCanceled() )
			return null;

		/* tolerate a range that was entered in reverse order */
		final int firstChoice = gd.getNextChoiceIndex();
		final int lastChoice = gd.getNextChoiceIndex();
		final List< Layer > layers = layerset.getLayers().subList(
				Math.min( firstChoice, lastChoice ),
				Math.max( firstChoice, lastChoice ) + 1 );

		scale = gd.getNextNumber();
		numCoefficients = ( int ) gd.getNextNumber();
		radius = ( int ) gd.getNextNumber();
		iterations = ( int ) gd.getNextNumber();
		lambda1 = gd.getNextNumber();
		lambda2 = gd.getNextNumber();
		neighborWeight = gd.getNextNumber();

		try
		{
			run( layers, radius, scale, numCoefficients, lambda1, lambda2, neighborWeight, getRoi( layerset ) );
		}
		catch ( final InterruptedException e )
		{
			Utils.log( "Match intensities interrupted." );
			e.printStackTrace( System.out );

			/* preserve the interrupt status for callers further up */
			Thread.currentThread().interrupt();
		}
		catch ( final ExecutionException e )
		{
			Utils.log( "Match intenities ExecutiuonException occurred:" );
			e.printStackTrace( System.out );
		}

		return null;
	}

	@Override
	public boolean applies( final Object ob )
	{
		return true;
	}

	/**
	 * Match the intensities of all patches of a range of layers that
	 * intersect a region of interest, save the resulting coefficients and
	 * update the mipmaps of the patches.
	 *
	 * <p>
	 * {@link #layerset} must have been assigned, see
	 * {@link #setup(Object...)}. The number of optimizer iterations is taken
	 * from the static field {@link #iterations}.
	 * </p>
	 *
	 * @param layers the layers whose patches are processed; the first and
	 *            last element bound the layer index range that is searched
	 *            for overlapping patches
	 * @param radius maximal layer index distance at which patches are tested
	 *            for overlap
	 * @param scale scale factor at which patch intersections are rendered
	 * @param numCoefficients number of coefficient tiles per row and per
	 *            column of each patch
	 * @param lambda1 lambda of the interpolation between
	 *            {@link AffineModel1D} and {@link TranslationModel1D}
	 * @param lambda2 lambda of the interpolation between the former
	 *            interpolated model and {@link IdentityModel}
	 * @param neighborWeight weight of the identity matches connecting
	 *            neighboring coefficient tiles of the same patch
	 * @param roi region of interest; only patches intersecting it are
	 *            processed
	 * @throws InterruptedException if waiting for the matching or mipmap
	 *             tasks was interrupted
	 * @throws ExecutionException if a matching or mipmap task failed
	 */
	@SuppressWarnings( { "unchecked", "rawtypes" } )
	public < M extends Model< M > & Affine1D< M > > void run(
			final List< Layer > layers,
			final int radius,
			final double scale,
			final int numCoefficients,
			final double lambda1,
			final double lambda2,
			final double neighborWeight,
			final Rectangle roi ) throws InterruptedException, ExecutionException
	{
		if ( layers == null || layers.isEmpty() )
		{
			Utils.log( "Match intensities: no layers selected." );
			return;
		}

		final int firstLayerIndex = layerset.getLayerIndex( layers.get( 0 ).getId() );
		final int lastLayerIndex = layerset.getLayerIndex( layers.get( layers.size() - 1 ).getId() );

//		final PointMatchFilter filter = new RansacRegressionFilter();
		final PointMatchFilter filter = new RansacRegressionReduceFilter();

		/* collect patches */
		Utils.log( "Collecting patches ... " );
		final ArrayList< Patch > patches = new ArrayList< Patch >();
		for ( final Layer layer : layers )
			patches.addAll( ( Collection ) layer.getDisplayables( Patch.class, roi ) );

		/* delete existing intensity coefficients */
		Utils.log( "Clearing existing intensity maps ... " );
		for ( final Patch p : patches )
			p.clearIntensityMap();

		/* generate coefficient tiles for all patches
		 * TODO consider offering alternative models */
		final int numTiles = numCoefficients * numCoefficients;
		final HashMap< Patch, ArrayList< Tile< ? extends M > > > coefficientsTiles =
				( HashMap ) generateCoefficientsTiles(
						patches,
						new InterpolatedAffineModel1D< InterpolatedAffineModel1D< AffineModel1D, TranslationModel1D >, IdentityModel >(
								new InterpolatedAffineModel1D< AffineModel1D, TranslationModel1D >(
										new AffineModel1D(),
										new TranslationModel1D(),
										lambda1 ),
								new IdentityModel(),
								lambda2 ),
						numTiles );

		/* completed patches */
		final HashSet< Patch > completedPatches = new HashSet< Patch >();

		/* collect patch pairs */
		Utils.log( "Collecting patch pairs ... " );
		final ArrayList< ValuePair< Patch, Patch > > patchPairs = new ArrayList< ValuePair< Patch, Patch > >();
		for ( final Patch p1 : patches )
		{
			completedPatches.add( p1 );

			final Rectangle box1 = p1.getBoundingBox().intersection( roi );

			final ArrayList< Patch > p2s = new ArrayList< Patch >();

			/* across adjacent layers */
			final int layerIndex = layerset.getLayerIndex( p1.getLayer().getId() );
			final int from = Math.max( firstLayerIndex, layerIndex - radius );
			final int to = Math.min( lastLayerIndex, layerIndex + radius );
			for ( int i = from; i <= to; ++i )
			{
				final Layer layer = layerset.getLayer( i );
				if ( layer != null )
					p2s.addAll( ( Collection ) layer.getDisplayables( Patch.class, box1 ) );
			}

			for ( final Patch p2 : p2s )
			{
				/*
				 * if this patch had been processed earlier, all matches are
				 * already in
				 */
				if ( completedPatches.contains( p2 ) )
					continue;

				patchPairs.add( new ValuePair< Patch, Patch >( p1, p2 ) );
			}
		}

		final int numThreads = Integer.parseInt( layerset.getProperty( "n_mipmap_threads", Integer.toString( Runtime.getRuntime().availableProcessors() ) ) );

		Utils.log( "Matching intensities using " + numThreads + " threads ... " );
		final ExecutorService exec = Executors.newFixedThreadPool( numThreads );
		try
		{
			final ArrayList< Future< ? > > futures = new ArrayList< Future< ? > >();
			for ( final ValuePair< Patch, Patch > patchPair : patchPairs )
				futures.add( exec.submit( new Matcher( roi, patchPair, ( HashMap ) coefficientsTiles, filter, scale, numCoefficients ) ) );

			for ( final Future< ? > future : futures )
				future.get();
		}
		finally
		{
			/* release the worker threads, they would otherwise linger forever */
			exec.shutdown();
		}

		/* connect tiles within patches */
		Utils.log( "Connecting coefficient tiles in the same patch ... " );
		for ( final Patch p1 : completedPatches )
		{
			/* get the coefficient tiles */
			final ArrayList< Tile< ? extends M > > p1CoefficientsTiles = coefficientsTiles.get( p1 );

			/* vertical neighbors */
			for ( int y = 1; y < numCoefficients; ++y )
			{
				final int yr = numCoefficients * y;
				final int yr1 = yr - numCoefficients;
				for ( int x = 0; x < numCoefficients; ++x )
					identityConnect( p1CoefficientsTiles.get( yr1 + x ), p1CoefficientsTiles.get( yr + x ), neighborWeight );
			}

			/* horizontal neighbors */
			for ( int y = 0; y < numCoefficients; ++y )
			{
				final int yr = numCoefficients * y;
				for ( int x = 1; x < numCoefficients; ++x )
				{
					final int yrx = yr + x;
					identityConnect( p1CoefficientsTiles.get( yrx ), p1CoefficientsTiles.get( yrx - 1 ), neighborWeight );
				}
			}
		}

		/* optimize */
		Utils.log( "Optimizing ... " );
		final TileConfiguration tc = new TileConfiguration();
		for ( final ArrayList< Tile< ? extends M > > coefficients : coefficientsTiles.values() )
		{
//			for ( final Tile< ? > t : coefficients )
//				if ( t.getMatches().size() == 0 )
//					IJ.log( "bang" );
			tc.addTiles( coefficients );
		}

		/*
		 * a failed optimization is reported but does not abort; the models
		 * are saved in whatever state the optimizer left them
		 */
		try
		{
			tc.optimize( 0.01f, iterations, iterations, 0.75f );
		}
		catch ( final NotEnoughDataPointsException e )
		{
			Utils.log( "Match intensities: optimization failed, not enough data points." );
			e.printStackTrace();
		}
		catch ( final IllDefinedDataPointsException e )
		{
			Utils.log( "Match intensities: optimization failed, ill defined data points." );
			e.printStackTrace();
		}

		/* save coefficients */
		final double[] ab = new double[ 2 ];
		final FSLoader loader = ( FSLoader ) layerset.getProject().getLoader();
		final String itsDir = loader.getUNUIdFolder() + "trakem2.its/";
		for ( final Entry< Patch, ArrayList< Tile< ? extends M > > > entry : coefficientsTiles.entrySet() )
		{
			final FloatProcessor as = new FloatProcessor( numCoefficients, numCoefficients );
			final FloatProcessor bs = new FloatProcessor( numCoefficients, numCoefficients );

			final Patch p = entry.getKey();

			final double min = p.getMin();
			final double max = p.getMax();

			final ArrayList< Tile< ? extends M > > tiles = entry.getValue();
			for ( int i = 0; i < numTiles; ++i )
			{
				final Tile< ? extends M > t = tiles.get( i );
				final Affine1D< ? > affine = t.getModel();
				affine.toArray( ab );

				/* coefficients mapping into existing [min, max] */
				as.setf( i, ( float ) ab[ 0 ] );
				bs.setf( i, ( float ) ( ( max - min ) * ab[ 1 ] + min - ab[ 0 ] * min ) );
			}
			final ImageStack coefficientsStack = new ImageStack( numCoefficients, numCoefficients );
			coefficientsStack.addSlice( as );
			coefficientsStack.addSlice( bs );

			final String itsPath = itsDir + FSLoader.createIdPath( Long.toString( p.getId() ), "it", ".tif" );
			new File( itsPath ).getParentFile().mkdirs();
			IJ.saveAs( new ImagePlus( "", coefficientsStack ), "tif", itsPath );
		}

		/* update mipmaps */
		for ( final Patch p : patches )
			p.getProject().getLoader().decacheImagePlus( p.getId() );

		final ArrayList< Future< Boolean > > mipmapFutures = new ArrayList< Future< Boolean > >();
		for ( final Patch p : patches )
			mipmapFutures.add( p.updateMipMaps() );

		for ( final Future< Boolean > f : mipmapFutures )
			f.get();

		Utils.log( "Matching intensities done." );
	}

	/**
	 * Development entry point: start ImageJ, open a fixed test project and
	 * invoke the plugin on its root layer set.
	 *
	 * @param args ignored
	 */
	final static public void main( final String... args )
	{
		new ImageJ();

		final Project project = Project.openFSProject( "/home/saalfeld/tmp/intensity-corrected/elastic.xml", true );

		final MatchIntensities matcher = new MatchIntensities();
		matcher.invoke( project.getRootLayerSet() );
	}
}